This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git


The following commit(s) were added to refs/heads/main by this push:
     new 413bd36  feat(internal/encoding): add Discard method to decoders (#280)
413bd36 is described below

commit 413bd36a3040ef1745b7842622b2de7abda9cbc3
Author: Matt Topol <[email protected]>
AuthorDate: Mon Feb 17 10:55:20 2025 -0500

    feat(internal/encoding): add Discard method to decoders (#280)
    
    ### Rationale for this change
    Building towards proper support for skipping rows to address #278, we
    need to be able to efficiently discard values from decoders rather than
    having to allocate a buffer and read into it. This makes skipping and
    seeking within a page more efficient.
    
    ### What changes are included in this PR?
    Adds a new `Discard` method to the `Decoder` interface, along with
    implementations for all of the various decoders.
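    
    As an illustrative sketch only (this is an internal package, so the helper
    below is hypothetical rather than user-facing API; only the
    `Discard(n int) (int, error)` signature on `TypedDecoder` comes from this
    change), a caller inside `parquet/internal` could skip a run of values
    like so:
    
    ```go
    import (
        "fmt"
    
        "github.com/apache/arrow-go/v18/parquet/internal/encoding"
    )
    
    // skipValues discards n values from a decoder that has already been primed
    // with SetData, reporting an error if fewer than n values could be skipped
    // (for example because the page data ran out early).
    func skipValues(dec encoding.TypedDecoder, n int) error {
        got, err := dec.Discard(n)
        if err != nil {
            return err
        }
        if got != n {
            return fmt.Errorf("parquet: discarded only %d of %d values", got, n)
        }
        return nil
    }
    ```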
    
    ### Are these changes tested?
    Yes, tests are added for the various decoders.
    
    ### Are there any user-facing changes?
    No, this is in an internal package.
---
 parquet/internal/encoding/boolean_decoder.go       | 29 ++++++++
 parquet/internal/encoding/byte_array_decoder.go    | 26 +++++++
 parquet/internal/encoding/byte_stream_split.go     | 13 +++-
 parquet/internal/encoding/decoder.go               |  6 ++
 parquet/internal/encoding/delta_bit_packing.go     | 44 +++++++++++
 parquet/internal/encoding/delta_byte_array.go      | 39 ++++++++++
 .../internal/encoding/delta_length_byte_array.go   | 10 +++
 parquet/internal/encoding/encoding_test.go         | 37 ++++++++++
 .../encoding/fixed_len_byte_array_decoder.go       | 29 +++++++-
 .../internal/encoding/plain_encoder_types.gen.go   | 60 +++++++++++++++
 .../encoding/plain_encoder_types.gen.go.tmpl       | 12 +++
 parquet/internal/encoding/typed_encoder.gen.go     | 85 ++++++++++++++++++++++
 .../internal/encoding/typed_encoder.gen.go.tmpl    | 12 +++
 parquet/internal/encoding/types.go                 |  4 +
 parquet/internal/utils/bit_reader.go               | 29 ++++++++
 parquet/internal/utils/rle.go                      | 27 +++++++
 16 files changed, 458 insertions(+), 4 deletions(-)

diff --git a/parquet/internal/encoding/boolean_decoder.go b/parquet/internal/encoding/boolean_decoder.go
index ae67f84..9a18f0a 100644
--- a/parquet/internal/encoding/boolean_decoder.go
+++ b/parquet/internal/encoding/boolean_decoder.go
@@ -50,6 +50,27 @@ func (dec *PlainBooleanDecoder) SetData(nvals int, data []byte) error {
        return nil
 }
 
+func (dec *PlainBooleanDecoder) Discard(n int) (int, error) {
+       n = min(n, dec.nvals)
+       dec.nvals -= n
+
+       if dec.bitOffset+n < 8 {
+               dec.bitOffset += n
+               return n, nil
+       }
+
+       remaining := n - (8 - dec.bitOffset)
+       dec.bitOffset = 0
+       dec.data = dec.data[1:]
+
+       bytesToSkip := bitutil.BytesForBits(int64(remaining/8) * 8)
+       dec.data = dec.data[bytesToSkip:]
+       remaining -= int(bytesToSkip * 8)
+
+       dec.bitOffset += remaining
+       return n, nil
+}
+
 // Decode fills out with bools decoded from the data at the current point
 // or until we reach the end of the data.
 //
@@ -147,6 +168,14 @@ func (dec *RleBooleanDecoder) SetData(nvals int, data []byte) error {
        return nil
 }
 
+func (dec *RleBooleanDecoder) Discard(n int) (int, error) {
+       n = min(n, dec.nvals)
+
+       n = dec.rleDec.Discard(n)
+       dec.nvals -= n
+       return n, nil
+}
+
 func (dec *RleBooleanDecoder) Decode(out []bool) (int, error) {
        max := shared_utils.Min(len(out), dec.nvals)
 
diff --git a/parquet/internal/encoding/byte_array_decoder.go b/parquet/internal/encoding/byte_array_decoder.go
index dc6c27d..1253cce 100644
--- a/parquet/internal/encoding/byte_array_decoder.go
+++ b/parquet/internal/encoding/byte_array_decoder.go
@@ -18,6 +18,7 @@ package encoding
 
 import (
        "encoding/binary"
+       "errors"
 
        "github.com/apache/arrow-go/v18/arrow"
        "github.com/apache/arrow-go/v18/arrow/array"
@@ -44,6 +45,31 @@ func (PlainByteArrayDecoder) Type() parquet.Type {
        return parquet.Types.ByteArray
 }
 
+func (pbad *PlainByteArrayDecoder) Discard(n int) (int, error) {
+       n = min(n, pbad.nvals)
+       // we have to skip the length of each value by first checking
+       // the length of the value and then skipping that many bytes
+       for i := 0; i < n; i++ {
+               if len(pbad.data) < 4 {
+                       return i, errors.New("parquet: eof skipping bytearray values")
+               }
+
+               valueLen := int32(binary.LittleEndian.Uint32(pbad.data[:4]))
+               if valueLen < 0 {
+                       return i, errors.New("parquet: invalid BYTE_ARRAY value")
+               }
+
+               if int64(len(pbad.data)) < int64(valueLen)+4 {
+                       return i, errors.New("parquet: eof skipping bytearray values")
+               }
+
+               pbad.data = pbad.data[valueLen+4:]
+       }
+
+       pbad.nvals -= n
+       return n, nil
+}
+
 // Decode will populate the slice of bytearrays in full or until the number
 // of values is consumed.
 //
diff --git a/parquet/internal/encoding/byte_stream_split.go b/parquet/internal/encoding/byte_stream_split.go
index 39c1209..e65a972 100644
--- a/parquet/internal/encoding/byte_stream_split.go
+++ b/parquet/internal/encoding/byte_stream_split.go
@@ -92,8 +92,8 @@ func encodeByteStreamSplitWidth8(data []byte, in []byte) {
 // into the output buffer 'out' using BYTE_STREAM_SPLIT encoding.
 // 'out' must have space for at least len(data) bytes.
 func decodeByteStreamSplitBatchWidth4(data []byte, nValues, stride int, out []byte) {
-       debug.Assert(len(out) >= len(data), fmt.Sprintf("not enough space in output buffer for decoding, out: %d bytes, data: %d bytes", len(out), len(data)))
        const width = 4
+       debug.Assert(len(out) >= nValues*width, fmt.Sprintf("not enough space in output buffer for decoding, out: %d bytes, data: %d bytes", len(out), len(data)))
        for element := 0; element < nValues; element++ {
                out[width*element] = data[element]
                out[width*element+1] = data[stride+element]
@@ -106,8 +106,8 @@ func decodeByteStreamSplitBatchWidth4(data []byte, nValues, stride int, out []by
 // into the output buffer 'out' using BYTE_STREAM_SPLIT encoding.
 // 'out' must have space for at least len(data) bytes.
 func decodeByteStreamSplitBatchWidth8(data []byte, nValues, stride int, out []byte) {
-       debug.Assert(len(out) >= len(data), fmt.Sprintf("not enough space in output buffer for decoding, out: %d bytes, data: %d bytes", len(out), len(data)))
        const width = 8
+       debug.Assert(len(out) >= nValues*width, fmt.Sprintf("not enough space in output buffer for decoding, out: %d bytes, data: %d bytes", len(out), len(data)))
        for element := 0; element < nValues; element++ {
                out[width*element] = data[element]
                out[width*element+1] = data[stride+element]
@@ -351,9 +351,16 @@ func (dec *ByteStreamSplitDecoder[T]) SetData(nvals int, data []byte) error {
        return dec.decoder.SetData(nvals, data)
 }
 
+func (dec *ByteStreamSplitDecoder[T]) Discard(n int) (int, error) {
+       n = min(n, dec.nvals)
+       dec.nvals -= n
+       dec.data = dec.data[n:]
+       return n, nil
+}
+
 func (dec *ByteStreamSplitDecoder[T]) Decode(out []T) (int, error) {
        typeLen := dec.Type().ByteSize()
-       toRead := len(out)
+       toRead := min(len(out), dec.nvals)
        numBytesNeeded := toRead * typeLen
        if numBytesNeeded > len(dec.data) || numBytesNeeded > math.MaxInt32 {
                return 0, xerrors.New("parquet: eof exception")
diff --git a/parquet/internal/encoding/decoder.go b/parquet/internal/encoding/decoder.go
index 50bd236..64455d1 100644
--- a/parquet/internal/encoding/decoder.go
+++ b/parquet/internal/encoding/decoder.go
@@ -142,6 +142,12 @@ func (d *dictDecoder) SetData(nvals int, data []byte) error {
        return nil
 }
 
+func (d *dictDecoder) discard(n int) (int, error) {
+       n = d.idxDecoder.Discard(n)
+       d.nvals -= n
+       return n, nil
+}
+
 func (d *dictDecoder) decode(out interface{}) (int, error) {
        n, err := d.idxDecoder.GetBatchWithDict(d.dictValueDecoder, out)
        d.nvals -= n
diff --git a/parquet/internal/encoding/delta_bit_packing.go b/parquet/internal/encoding/delta_bit_packing.go
index 0ead443..a57b12c 100644
--- a/parquet/internal/encoding/delta_bit_packing.go
+++ b/parquet/internal/encoding/delta_bit_packing.go
@@ -150,6 +150,50 @@ func (d *deltaBitPackDecoder[T]) unpackNextMini() error {
        return nil
 }
 
+func (d *deltaBitPackDecoder[T]) Discard(n int) (int, error) {
+       n = min(n, int(d.nvals))
+       if n == 0 {
+               return 0, nil
+       }
+
+       var (
+               err       error
+               remaining = n
+       )
+
+       if !d.usedFirst {
+               d.usedFirst = true
+               remaining--
+       }
+
+       for remaining > 0 {
+               if d.currentBlockVals == 0 {
+                       if err = d.initBlock(); err != nil {
+                               return n - remaining, err
+                       }
+               }
+
+               if d.currentMiniBlockVals == 0 {
+                       if err = d.unpackNextMini(); err != nil {
+                               return n - remaining, err
+                       }
+               }
+
+               start := d.valsPerMini - d.currentMiniBlockVals
+               numToDiscard := len(d.miniBlockValues[start:])
+               if numToDiscard > remaining {
+                       numToDiscard = remaining
+               }
+
+               d.currentBlockVals -= uint32(numToDiscard)
+               d.currentMiniBlockVals -= uint32(numToDiscard)
+               remaining -= numToDiscard
+       }
+
+       d.nvals -= n
+       return n, nil
+}
+
 // Decode retrieves min(remaining values, len(out)) values from the data and returns the number
 // of values actually decoded and any errors encountered.
 func (d *deltaBitPackDecoder[T]) Decode(out []T) (int, error) {
diff --git a/parquet/internal/encoding/delta_byte_array.go b/parquet/internal/encoding/delta_byte_array.go
index f86bb58..bb2134a 100644
--- a/parquet/internal/encoding/delta_byte_array.go
+++ b/parquet/internal/encoding/delta_byte_array.go
@@ -181,6 +181,45 @@ func (d *DeltaByteArrayDecoder) SetData(nvalues int, data []byte) error {
        return d.DeltaLengthByteArrayDecoder.SetData(nvalues, data[int(prefixLenDec.bytesRead()):])
 }
 
+func (d *DeltaByteArrayDecoder) Discard(n int) (int, error) {
+       n = min(n, d.nvals)
+       if n == 0 {
+               return 0, nil
+       }
+
+       remaining := n
+       tmp := make([]parquet.ByteArray, 1)
+       if d.lastVal == nil {
+               if _, err := d.DeltaLengthByteArrayDecoder.Decode(tmp); err != nil {
+                       return 0, err
+               }
+               d.lastVal = tmp[0]
+               d.prefixLengths = d.prefixLengths[1:]
+               remaining--
+       }
+
+       var prefixLen int32
+       for remaining > 0 {
+               prefixLen, d.prefixLengths = d.prefixLengths[0], d.prefixLengths[1:]
+               prefix := d.lastVal[:prefixLen:prefixLen]
+
+               if _, err := d.DeltaLengthByteArrayDecoder.Decode(tmp); err != nil {
+                       return n - remaining, err
+               }
+
+               if len(tmp[0]) == 0 {
+                       d.lastVal = prefix
+               } else {
+                       d.lastVal = make([]byte, int(prefixLen)+len(tmp[0]))
+                       copy(d.lastVal, prefix)
+                       copy(d.lastVal[prefixLen:], tmp[0])
+               }
+               remaining--
+       }
+
+       return n, nil
+}
+
 // Decode decodes byte arrays into the slice provided and returns the number of values actually decoded
 func (d *DeltaByteArrayDecoder) Decode(out []parquet.ByteArray) (int, error) {
        max := utils.Min(len(out), d.nvals)
diff --git a/parquet/internal/encoding/delta_length_byte_array.go b/parquet/internal/encoding/delta_length_byte_array.go
index 6d3b570..eab064c 100644
--- a/parquet/internal/encoding/delta_length_byte_array.go
+++ b/parquet/internal/encoding/delta_length_byte_array.go
@@ -123,6 +123,16 @@ func (d *DeltaLengthByteArrayDecoder) SetData(nvalues int, data []byte) error {
        return d.decoder.SetData(nvalues, data[int(dec.bytesRead()):])
 }
 
+func (d *DeltaLengthByteArrayDecoder) Discard(n int) (int, error) {
+       n = min(n, d.nvals)
+       for i := 0; i < n; i++ {
+               d.data = d.data[d.lengths[i]:]
+       }
+       d.nvals -= n
+       d.lengths = d.lengths[n:]
+       return n, nil
+}
+
 // Decode populates the passed in slice with data decoded until it hits the length of out
 // or runs out of values in the column to decode, then returns the number of values actually decoded.
 func (d *DeltaLengthByteArrayDecoder) Decode(out []parquet.ByteArray) (int, error) {
diff --git a/parquet/internal/encoding/encoding_test.go b/parquet/internal/encoding/encoding_test.go
index 9f86882..93e830d 100644
--- a/parquet/internal/encoding/encoding_test.go
+++ b/parquet/internal/encoding/encoding_test.go
@@ -307,6 +307,22 @@ func (b *BaseEncodingTestSuite) encodeTestData(e parquet.Encoding) (encoding.Buf
        return enc.FlushValues()
 }
 
+func (b *BaseEncodingTestSuite) testDiscardDecodeData(e parquet.Encoding, buf []byte) {
+       dec := encoding.NewDecoder(testutils.TypeToParquetType(b.typ), e, b.descr, b.mem)
+       b.Equal(e, dec.Encoding())
+       b.Equal(b.descr.PhysicalType(), dec.Type())
+
+       dec.SetData(b.nvalues, buf)
+       discarded := b.nvalues / 2
+       n, err := dec.Discard(discarded)
+       b.Require().NoError(err)
+       b.Equal(discarded, n)
+
+       decoded, _ := decode(dec, b.decodeBuf)
+       b.Equal(b.nvalues-discarded, decoded)
+       b.Equal(reflect.ValueOf(b.draws).Slice(discarded, b.nvalues).Interface(), reflect.ValueOf(b.decodeBuf).Slice(0, b.nvalues-discarded).Interface())
+}
+
 func (b *BaseEncodingTestSuite) decodeTestData(e parquet.Encoding, buf []byte) {
        dec := encoding.NewDecoder(testutils.TypeToParquetType(b.typ), e, b.descr, b.mem)
        b.Equal(e, dec.Encoding())
@@ -316,6 +332,16 @@ func (b *BaseEncodingTestSuite) decodeTestData(e parquet.Encoding, buf []byte) {
        decoded, _ := decode(dec, b.decodeBuf)
        b.Equal(b.nvalues, decoded)
        b.Equal(reflect.ValueOf(b.draws).Slice(0, b.nvalues).Interface(), reflect.ValueOf(b.decodeBuf).Slice(0, b.nvalues).Interface())
+
+       dec.SetData(b.nvalues, buf)
+       decoded = 0
+       for i := 0; i < b.nvalues; i += 500 {
+               n, err := decode(dec, reflect.ValueOf(b.decodeBuf).Slice(i, i+500).Interface())
+               b.Require().NoError(err)
+               decoded += n
+       }
+       b.Equal(b.nvalues, decoded)
+       b.Equal(reflect.ValueOf(b.draws).Slice(0, b.nvalues).Interface(), reflect.ValueOf(b.decodeBuf).Slice(0, b.nvalues).Interface())
 }
 
 func (b *BaseEncodingTestSuite) encodeTestDataSpaced(e parquet.Encoding, validBits []byte, validBitsOffset int64) (encoding.Buffer, error) {
@@ -343,6 +369,7 @@ func (b *BaseEncodingTestSuite) checkRoundTrip(e parquet.Encoding) {
        buf, _ := b.encodeTestData(e)
        defer buf.Release()
        b.decodeTestData(e, buf.Bytes())
+       b.testDiscardDecodeData(e, buf.Bytes())
 }
 
 func (b *BaseEncodingTestSuite) checkRoundTripSpaced(e parquet.Encoding, validBits []byte, validBitsOffset int64) {
@@ -541,6 +568,16 @@ func (d *DictionaryEncodingTestSuite) checkRoundTrip() {
        decoded, _ = decodeSpaced(decoder, d.decodeBuf, 0, validBits, 0)
        d.Equal(d.nvalues, decoded)
        d.Equal(reflect.ValueOf(d.draws).Slice(0, d.nvalues).Interface(), reflect.ValueOf(d.decodeBuf).Slice(0, d.nvalues).Interface())
+
+       decoder.SetData(d.nvalues, indices.Bytes())
+       discarded := d.nvalues / 2
+       n, err := decoder.Discard(discarded)
+       d.Require().NoError(err)
+       d.Equal(discarded, n)
+
+       decoded, _ = decode(decoder, d.decodeBuf)
+       d.Equal(d.nvalues-discarded, decoded)
+       d.Equal(reflect.ValueOf(d.draws).Slice(discarded, d.nvalues).Interface(), reflect.ValueOf(d.decodeBuf).Slice(0, d.nvalues-discarded).Interface())
 }
 
 func (d *DictionaryEncodingTestSuite) TestBasicRoundTrip() {
diff --git a/parquet/internal/encoding/fixed_len_byte_array_decoder.go b/parquet/internal/encoding/fixed_len_byte_array_decoder.go
index 0abd036..6e389ae 100644
--- a/parquet/internal/encoding/fixed_len_byte_array_decoder.go
+++ b/parquet/internal/encoding/fixed_len_byte_array_decoder.go
@@ -17,6 +17,7 @@
 package encoding
 
 import (
+       "errors"
        "fmt"
        "math"
 
@@ -35,6 +36,18 @@ func (PlainFixedLenByteArrayDecoder) Type() parquet.Type {
        return parquet.Types.FixedLenByteArray
 }
 
+func (pflba *PlainFixedLenByteArrayDecoder) Discard(n int) (int, error) {
+       n = min(n, pflba.nvals)
+       numBytesNeeded := n * pflba.typeLen
+       if numBytesNeeded > len(pflba.data) || numBytesNeeded > math.MaxInt32 {
+               return 0, errors.New("parquet: eof exception")
+       }
+
+       pflba.data = pflba.data[numBytesNeeded:]
+       pflba.nvals -= n
+       return n, nil
+}
+
 // Decode populates out with fixed length byte array values until either there are no more
 // values to decode or the length of out has been filled. Then returns the total number of values
 // that were decoded.
@@ -49,6 +62,8 @@ func (pflba *PlainFixedLenByteArrayDecoder) Decode(out []parquet.FixedLenByteArr
                out[idx] = pflba.data[:pflba.typeLen]
                pflba.data = pflba.data[pflba.typeLen:]
        }
+
+       pflba.nvals -= max
        return max, nil
 }
 
@@ -92,8 +107,20 @@ func (dec *ByteStreamSplitFixedLenByteArrayDecoder) SetData(nvals int, data []by
        return dec.decoder.SetData(nvals, data)
 }
 
+func (dec *ByteStreamSplitFixedLenByteArrayDecoder) Discard(n int) (int, error) {
+       n = min(n, dec.nvals)
+       numBytesNeeded := n * dec.typeLen
+       if numBytesNeeded > len(dec.data) || numBytesNeeded > math.MaxInt32 {
+               return 0, errors.New("parquet: eof exception")
+       }
+
+       dec.nvals -= n
+       dec.data = dec.data[n:]
+       return n, nil
+}
+
 func (dec *ByteStreamSplitFixedLenByteArrayDecoder) Decode(out []parquet.FixedLenByteArray) (int, error) {
-       toRead := len(out)
+       toRead := min(len(out), dec.nvals)
        numBytesNeeded := toRead * dec.typeLen
        if numBytesNeeded > len(dec.data) || numBytesNeeded > math.MaxInt32 {
                return 0, xerrors.New("parquet: eof exception")
diff --git a/parquet/internal/encoding/plain_encoder_types.gen.go b/parquet/internal/encoding/plain_encoder_types.gen.go
index 8cc187d..a499bbb 100644
--- a/parquet/internal/encoding/plain_encoder_types.gen.go
+++ b/parquet/internal/encoding/plain_encoder_types.gen.go
@@ -168,6 +168,18 @@ func (PlainInt32Decoder) Type() parquet.Type {
        return parquet.Types.Int32
 }
 
+func (dec *PlainInt32Decoder) Discard(n int) (int, error) {
+       n = min(n, dec.nvals)
+       nbytes := int64(n) * int64(arrow.Int32SizeBytes)
+       if nbytes > int64(len(dec.data)) || nbytes > math.MaxInt32 {
+               return 0, fmt.Errorf("parquet: eof exception discard plain Int32, nvals: %d, nbytes: %d, datalen: %d", dec.nvals, nbytes, len(dec.data))
+       }
+
+       dec.data = dec.data[nbytes:]
+       dec.nvals -= n
+       return n, nil
+}
+
 // Decode populates the given slice with values from the data to be decoded,
 // decoding the min(len(out), remaining values).
 // It returns the number of values actually decoded and any error encountered.
@@ -273,6 +285,18 @@ func (PlainInt64Decoder) Type() parquet.Type {
        return parquet.Types.Int64
 }
 
+func (dec *PlainInt64Decoder) Discard(n int) (int, error) {
+       n = min(n, dec.nvals)
+       nbytes := int64(n) * int64(arrow.Int64SizeBytes)
+       if nbytes > int64(len(dec.data)) || nbytes > math.MaxInt32 {
+               return 0, fmt.Errorf("parquet: eof exception discard plain Int64, nvals: %d, nbytes: %d, datalen: %d", dec.nvals, nbytes, len(dec.data))
+       }
+
+       dec.data = dec.data[nbytes:]
+       dec.nvals -= n
+       return n, nil
+}
+
 // Decode populates the given slice with values from the data to be decoded,
 // decoding the min(len(out), remaining values).
 // It returns the number of values actually decoded and any error encountered.
@@ -378,6 +402,18 @@ func (PlainInt96Decoder) Type() parquet.Type {
        return parquet.Types.Int96
 }
 
+func (dec *PlainInt96Decoder) Discard(n int) (int, error) {
+       n = min(n, dec.nvals)
+       nbytes := int64(n) * int64(parquet.Int96SizeBytes)
+       if nbytes > int64(len(dec.data)) || nbytes > math.MaxInt32 {
+               return 0, fmt.Errorf("parquet: eof exception discard plain Int96, nvals: %d, nbytes: %d, datalen: %d", dec.nvals, nbytes, len(dec.data))
+       }
+
+       dec.data = dec.data[nbytes:]
+       dec.nvals -= n
+       return n, nil
+}
+
 // Decode populates the given slice with values from the data to be decoded,
 // decoding the min(len(out), remaining values).
 // It returns the number of values actually decoded and any error encountered.
@@ -483,6 +519,18 @@ func (PlainFloat32Decoder) Type() parquet.Type {
        return parquet.Types.Float
 }
 
+func (dec *PlainFloat32Decoder) Discard(n int) (int, error) {
+       n = min(n, dec.nvals)
+       nbytes := int64(n) * int64(arrow.Float32SizeBytes)
+       if nbytes > int64(len(dec.data)) || nbytes > math.MaxInt32 {
+               return 0, fmt.Errorf("parquet: eof exception discard plain Float32, nvals: %d, nbytes: %d, datalen: %d", dec.nvals, nbytes, len(dec.data))
+       }
+
+       dec.data = dec.data[nbytes:]
+       dec.nvals -= n
+       return n, nil
+}
+
 // Decode populates the given slice with values from the data to be decoded,
 // decoding the min(len(out), remaining values).
 // It returns the number of values actually decoded and any error encountered.
@@ -588,6 +636,18 @@ func (PlainFloat64Decoder) Type() parquet.Type {
        return parquet.Types.Double
 }
 
+func (dec *PlainFloat64Decoder) Discard(n int) (int, error) {
+       n = min(n, dec.nvals)
+       nbytes := int64(n) * int64(arrow.Float64SizeBytes)
+       if nbytes > int64(len(dec.data)) || nbytes > math.MaxInt32 {
+               return 0, fmt.Errorf("parquet: eof exception discard plain Float64, nvals: %d, nbytes: %d, datalen: %d", dec.nvals, nbytes, len(dec.data))
+       }
+
+       dec.data = dec.data[nbytes:]
+       dec.nvals -= n
+       return n, nil
+}
+
 // Decode populates the given slice with values from the data to be decoded,
 // decoding the min(len(out), remaining values).
 // It returns the number of values actually decoded and any error encountered.
diff --git a/parquet/internal/encoding/plain_encoder_types.gen.go.tmpl b/parquet/internal/encoding/plain_encoder_types.gen.go.tmpl
index 71d86de..4af6edd 100644
--- a/parquet/internal/encoding/plain_encoder_types.gen.go.tmpl
+++ b/parquet/internal/encoding/plain_encoder_types.gen.go.tmpl
@@ -129,6 +129,18 @@ func (Plain{{.Name}}Decoder) Type() parquet.Type {
   return parquet.Types.{{if .physical}}{{.physical}}{{else}}{{.Name}}{{end}}
 }
 
+func (dec *Plain{{.Name}}Decoder) Discard(n int) (int, error) {
+  n = min(n, dec.nvals)
+  nbytes := int64(n) * int64({{.prefix}}.{{.Name}}SizeBytes)
+  if nbytes > int64(len(dec.data)) || nbytes > math.MaxInt32 {
+    return 0, fmt.Errorf("parquet: eof exception discard plain {{.Name}}, nvals: %d, nbytes: %d, datalen: %d", dec.nvals, nbytes, len(dec.data))
+  }
+
+  dec.data = dec.data[nbytes:]
+  dec.nvals -= n
+  return n, nil
+}
+
 // Decode populates the given slice with values from the data to be decoded,
 // decoding the min(len(out), remaining values).
 // It returns the number of values actually decoded and any error encountered.
diff --git a/parquet/internal/encoding/typed_encoder.gen.go b/parquet/internal/encoding/typed_encoder.gen.go
index 04a7933..bde680e 100644
--- a/parquet/internal/encoding/typed_encoder.gen.go
+++ b/parquet/internal/encoding/typed_encoder.gen.go
@@ -19,6 +19,7 @@
 package encoding
 
 import (
+       "errors"
        "fmt"
        "unsafe"
 
@@ -195,6 +196,18 @@ func (DictInt32Decoder) Type() parquet.Type {
        return parquet.Types.Int32
 }
 
+func (d *DictInt32Decoder) Discard(n int) (int, error) {
+       n = min(n, d.nvals)
+       discarded, err := d.discard(n)
+       if err != nil {
+               return discarded, err
+       }
+       if n != discarded {
+               return discarded, errors.New("parquet: dict eof exception")
+       }
+       return n, nil
+}
+
 // Decode populates the passed in slice with min(len(out), remaining values) values,
 // decoding using the dictionary to get the actual values. Returns the number of values
 // actually decoded and any error encountered.
@@ -436,6 +449,18 @@ func (DictInt64Decoder) Type() parquet.Type {
        return parquet.Types.Int64
 }
 
+func (d *DictInt64Decoder) Discard(n int) (int, error) {
+       n = min(n, d.nvals)
+       discarded, err := d.discard(n)
+       if err != nil {
+               return discarded, err
+       }
+       if n != discarded {
+               return discarded, errors.New("parquet: dict eof exception")
+       }
+       return n, nil
+}
+
 // Decode populates the passed in slice with min(len(out), remaining values) values,
 // decoding using the dictionary to get the actual values. Returns the number of values
 // actually decoded and any error encountered.
@@ -651,6 +676,18 @@ func (DictInt96Decoder) Type() parquet.Type {
        return parquet.Types.Int96
 }
 
+func (d *DictInt96Decoder) Discard(n int) (int, error) {
+       n = min(n, d.nvals)
+       discarded, err := d.discard(n)
+       if err != nil {
+               return discarded, err
+       }
+       if n != discarded {
+               return discarded, errors.New("parquet: dict eof exception")
+       }
+       return n, nil
+}
+
 // Decode populates the passed in slice with min(len(out), remaining values) values,
 // decoding using the dictionary to get the actual values. Returns the number of values
 // actually decoded and any error encountered.
@@ -880,6 +917,18 @@ func (DictFloat32Decoder) Type() parquet.Type {
        return parquet.Types.Float
 }
 
+func (d *DictFloat32Decoder) Discard(n int) (int, error) {
+       n = min(n, d.nvals)
+       discarded, err := d.discard(n)
+       if err != nil {
+               return discarded, err
+       }
+       if n != discarded {
+               return discarded, errors.New("parquet: dict eof exception")
+       }
+       return n, nil
+}
+
 // Decode populates the passed in slice with min(len(out), remaining values) values,
 // decoding using the dictionary to get the actual values. Returns the number of values
 // actually decoded and any error encountered.
@@ -1109,6 +1158,18 @@ func (DictFloat64Decoder) Type() parquet.Type {
        return parquet.Types.Double
 }
 
+func (d *DictFloat64Decoder) Discard(n int) (int, error) {
+       n = min(n, d.nvals)
+       discarded, err := d.discard(n)
+       if err != nil {
+               return discarded, err
+       }
+       if n != discarded {
+               return discarded, errors.New("parquet: dict eof exception")
+       }
+       return n, nil
+}
+
 // Decode populates the passed in slice with min(len(out), remaining values) values,
 // decoding using the dictionary to get the actual values. Returns the number of values
 // actually decoded and any error encountered.
@@ -1378,6 +1439,18 @@ func (DictByteArrayDecoder) Type() parquet.Type {
        return parquet.Types.ByteArray
 }
 
+func (d *DictByteArrayDecoder) Discard(n int) (int, error) {
+       n = min(n, d.nvals)
+       discarded, err := d.discard(n)
+       if err != nil {
+               return discarded, err
+       }
+       if n != discarded {
+               return discarded, errors.New("parquet: dict eof exception")
+       }
+       return n, nil
+}
+
 // Decode populates the passed in slice with min(len(out), remaining values) values,
 // decoding using the dictionary to get the actual values. Returns the number of values
 // actually decoded and any error encountered.
@@ -1561,6 +1634,18 @@ func (DictFixedLenByteArrayDecoder) Type() parquet.Type {
        return parquet.Types.FixedLenByteArray
 }
 
+func (d *DictFixedLenByteArrayDecoder) Discard(n int) (int, error) {
+       n = min(n, d.nvals)
+       discarded, err := d.discard(n)
+       if err != nil {
+               return discarded, err
+       }
+       if n != discarded {
+               return discarded, errors.New("parquet: dict eof exception")
+       }
+       return n, nil
+}
+
 // Decode populates the passed in slice with min(len(out), remaining values) values,
 // decoding using the dictionary to get the actual values. Returns the number of values
 // actually decoded and any error encountered.
diff --git a/parquet/internal/encoding/typed_encoder.gen.go.tmpl b/parquet/internal/encoding/typed_encoder.gen.go.tmpl
index 49ae990..e9f3ce5 100644
--- a/parquet/internal/encoding/typed_encoder.gen.go.tmpl
+++ b/parquet/internal/encoding/typed_encoder.gen.go.tmpl
@@ -276,6 +276,18 @@ func (Dict{{.Name}}Decoder) Type() parquet.Type {
   return parquet.Types.{{if .physical}}{{.physical}}{{else}}{{.Name}}{{end}}
 }
 
+func (d *Dict{{.Name}}Decoder) Discard(n int) (int, error) {
+  n = min(n, d.nvals)
+  discarded, err := d.discard(n)
+  if err != nil {
+    return discarded, err
+  }
+  if n != discarded {
+    return discarded, errors.New("parquet: dict eof exception")
+  }
+  return n, nil
+}
+
 // Decode populates the passed in slice with min(len(out), remaining values) values,
 // decoding using the dictionary to get the actual values. Returns the number of values
 // actually decoded and any error encountered.
diff --git a/parquet/internal/encoding/types.go b/parquet/internal/encoding/types.go
index ca94a4c..5bd9949 100644
--- a/parquet/internal/encoding/types.go
+++ b/parquet/internal/encoding/types.go
@@ -40,6 +40,10 @@ type TypedDecoder interface {
        ValuesLeft() int
        // Type returns the physical type this can decode.
        Type() parquet.Type
+       // Discard the next n values from the decoder, returning the actual number
+       // of values that were able to be discarded (should be equal to n unless an
+       // error occurs).
+       Discard(n int) (int, error)
 }
 
 // DictDecoder is a special TypedDecoder which implements dictionary decoding
diff --git a/parquet/internal/utils/bit_reader.go b/parquet/internal/utils/bit_reader.go
index 2b6f048..1080529 100644
--- a/parquet/internal/utils/bit_reader.go
+++ b/parquet/internal/utils/bit_reader.go
@@ -289,6 +289,35 @@ func (b *BitReader) GetBatchBools(out []bool) (int, error) {
        return i, nil
 }
 
+func (b *BitReader) Discard(bits uint, n int) (int, error) {
+       if bits > 64 {
+               return 0, errors.New("must be 64 bits or less per read")
+       }
+
+       i := 0
+       for ; i < n && b.bitoffset != 0; i++ {
+               if _, err := b.next(bits); err != nil {
+                       return i, err
+               }
+       }
+
+       if n-i > 32 {
+               toSkip := (n - i) / 32 * 32
+
+               bytesToSkip := bitutil.BytesForBits(int64(toSkip * int(bits)))
+               b.byteoffset += int64(bytesToSkip)
+               i += toSkip
+       }
+
+       b.fillbuffer()
+       for ; i < n; i++ {
+               if _, err := b.next(bits); err != nil {
+                       return i, err
+               }
+       }
+       return n, nil
+}
+
 // GetBatch fills out by decoding values repeated from the stream that are encoded
 // using bits as the number of bits per value. The values are expected to be bit packed
 // so we will unpack the values to populate.
diff --git a/parquet/internal/utils/rle.go b/parquet/internal/utils/rle.go
index 333e72d..12eca7b 100644
--- a/parquet/internal/utils/rle.go
+++ b/parquet/internal/utils/rle.go
@@ -178,6 +178,33 @@ func (r *RleDecoder) GetValue() (uint64, bool) {
        return vals[0], n == 1
 }
 
+func (r *RleDecoder) Discard(n int) int {
+       read := 0
+       for read < n {
+               remain := n - read
+
+               if r.repCount > 0 {
+                       repbatch := int(math.Min(float64(remain), float64(r.repCount)))
+                       r.repCount -= int32(repbatch)
+                       read += repbatch
+               } else if r.litCount > 0 {
+                       litbatch := int(math.Min(float64(remain), float64(r.litCount)))
+                       n, _ := r.r.Discard(uint(r.bitWidth), litbatch)
+                       if n != litbatch {
+                               return read
+                       }
+
+                       r.litCount -= int32(litbatch)
+                       read += litbatch
+               } else {
+                       if !r.Next() {
+                               return read
+                       }
+               }
+       }
+       return read
+}
+
 func (r *RleDecoder) GetBatch(values []uint64) int {
        read := 0
        size := len(values)
