This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git


The following commit(s) were added to refs/heads/main by this push:
     new 1a18113  feat(arrow/ipc): add functions to generate payloads (#196)
1a18113 is described below

commit 1a181136abb25c2e815709a6e406ed99309c10e8
Author: Matt Topol <[email protected]>
AuthorDate: Tue Nov 26 16:59:53 2024 -0500

    feat(arrow/ipc): add functions to generate payloads (#196)
    
    ### Rationale for this change
    https://github.com/apache/arrow/issues/39730 requested separate
    functions to generate the IPC payloads for record batches (we already
    have one for schemas in the Flight package, but it makes sense to have
    one in ipc, which is more efficient).
    
    ### What changes are included in this PR?
    Two functions are added, `GetRecordBatchPayload` and `GetSchemaPayload`,
    which return ipc payload objects.
    
    ### Are these changes tested?
    Yes, a unit test is added.
    
    ### Are there any user-facing changes?
    No breaking changes; the two new exported functions are purely additive.
---
 arrow/ipc/writer.go      | 49 +++++++++++++++++++++---------
 arrow/ipc/writer_test.go | 78 ++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 113 insertions(+), 14 deletions(-)

diff --git a/arrow/ipc/writer.go b/arrow/ipc/writer.go
index 7ff4267..96f082f 100644
--- a/arrow/ipc/writer.go
+++ b/arrow/ipc/writer.go
@@ -603,7 +603,7 @@ func (w *recordEncoder) visit(p *Payload, arr arrow.Array) 
error {
                        // non-zero offset: slice the buffer
                        offset := int64(data.Offset()) * typeWidth
                        // send padding if available
-                       len := minI64(bitutil.CeilByte64(arrLen*typeWidth), 
int64(values.Len())-offset)
+                       len := min(bitutil.CeilByte64(arrLen*typeWidth), 
int64(values.Len())-offset)
                        values = memory.NewBufferBytes(values.Bytes()[offset : 
offset+len])
                default:
                        if values != nil {
@@ -628,7 +628,7 @@ func (w *recordEncoder) visit(p *Payload, arr arrow.Array) 
error {
                        // slice data buffer to include the range we need now.
                        var (
                                beg int64 = 0
-                               len       = minI64(paddedLength(totalDataBytes, 
kArrowAlignment), int64(totalDataBytes))
+                               len       = min(paddedLength(totalDataBytes, 
kArrowAlignment), int64(totalDataBytes))
                        )
                        if arr.Len() > 0 {
                                beg = arr.ValueOffset64(0)
@@ -655,7 +655,7 @@ func (w *recordEncoder) visit(p *Payload, arr arrow.Array) 
error {
                        // non-zero offset: slice the buffer
                        offset := data.Offset() * int(typeWidth)
                        // send padding if available
-                       len := int(minI64(bitutil.CeilByte64(arrLen*typeWidth), 
int64(values.Len()-offset)))
+                       len := int(min(bitutil.CeilByte64(arrLen*typeWidth), 
int64(values.Len()-offset)))
                        values = memory.SliceBuffer(values, offset, len)
                default:
                        if values != nil {
@@ -1028,7 +1028,7 @@ func (w *recordEncoder) rebaseDenseUnionValueOffsets(arr 
*array.DenseUnion, offs
                } else {
                        shiftedOffsets[i] = unshiftedOffsets[i] - offsets[c]
                }
-               lengths[c] = maxI32(lengths[c], shiftedOffsets[i]+1)
+               lengths[c] = max(lengths[c], shiftedOffsets[i]+1)
        }
        return shiftedOffsetsBuf
 }
@@ -1071,7 +1071,7 @@ func getTruncatedBuffer(offset, length int64, byteWidth 
int32, buf *memory.Buffe
 
        paddedLen := paddedLength(length*int64(byteWidth), kArrowAlignment)
        if offset != 0 || paddedLen < int64(buf.Len()) {
-               return memory.SliceBuffer(buf, int(offset*int64(byteWidth)), 
int(minI64(paddedLen, int64(buf.Len()))))
+               return memory.SliceBuffer(buf, int(offset*int64(byteWidth)), 
int(min(paddedLen, int64(buf.Len()))))
        }
        buf.Retain()
        return buf
@@ -1084,16 +1084,37 @@ func needTruncate(offset int64, buf *memory.Buffer, 
minLength int64) bool {
        return offset != 0 || minLength < int64(buf.Len())
 }
 
-func minI64(a, b int64) int64 {
-       if a < b {
-               return a
+// GetRecordBatchPayload produces the ipc payload for a given record batch.
+// The resulting payload itself must be released by the caller via the Release
+// method after it is no longer needed.
+func GetRecordBatchPayload(batch arrow.Record, opts ...Option) (Payload, 
error) {
+       cfg := newConfig(opts...)
+       var (
+               data = Payload{msg: MessageRecordBatch}
+               enc  = newRecordEncoder(
+                       cfg.alloc,
+                       0,
+                       kMaxNestingDepth,
+                       true,
+                       cfg.codec,
+                       cfg.compressNP,
+                       cfg.minSpaceSavings,
+                       make([]compressor, cfg.compressNP),
+               )
+       )
+
+       err := enc.Encode(&data, batch)
+       if err != nil {
+               return Payload{}, err
        }
-       return b
+
+       return data, nil
 }
 
-func maxI32(a, b int32) int32 {
-       if a > b {
-               return a
-       }
-       return b
+// GetSchemaPayload produces the ipc payload for a given schema.
+func GetSchemaPayload(schema *arrow.Schema, mem memory.Allocator) Payload {
+       var mapper dictutils.Mapper
+       mapper.ImportSchema(schema)
+       ps := payloadFromSchema(schema, mem, &mapper)
+       return ps[0]
 }
diff --git a/arrow/ipc/writer_test.go b/arrow/ipc/writer_test.go
index c081e44..dfc23a7 100644
--- a/arrow/ipc/writer_test.go
+++ b/arrow/ipc/writer_test.go
@@ -20,6 +20,7 @@ import (
        "bytes"
        "encoding/binary"
        "fmt"
+       "io"
        "math"
        "strings"
        "testing"
@@ -254,3 +255,80 @@ func TestWriterInferSchema(t *testing.T) {
 
        require.True(t, r.Schema().Equal(rec.Schema()))
 }
+
+type testMsgReader struct {
+       messages []*Message
+
+       curmsg *Message
+}
+
+func (r *testMsgReader) Message() (*Message, error) {
+       if r.curmsg != nil {
+               r.curmsg.Release()
+               r.curmsg = nil
+       }
+
+       if len(r.messages) == 0 {
+               return nil, io.EOF
+       }
+
+       r.curmsg = r.messages[0]
+       r.messages = r.messages[1:]
+       return r.curmsg, nil
+}
+
+func (r *testMsgReader) Release() {
+       if r.curmsg != nil {
+               r.curmsg.Release()
+               r.curmsg = nil
+       }
+       for _, m := range r.messages {
+               m.Release()
+       }
+       r.messages = nil
+}
+
+func (r *testMsgReader) Retain() {}
+
+func TestGetPayloads(t *testing.T) {
+       mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+       defer mem.AssertSize(t, 0)
+
+       schema := arrow.NewSchema([]arrow.Field{
+               {Name: "s", Type: arrow.BinaryTypes.String},
+       }, nil)
+
+       b := array.NewRecordBuilder(mem, schema)
+       defer b.Release()
+
+       b.Field(0).(*array.StringBuilder).AppendValues([]string{"foo", "bar", 
"baz"}, nil)
+       rec := b.NewRecord()
+       defer rec.Release()
+
+       schemaPayload := GetSchemaPayload(rec.Schema(), mem)
+       defer schemaPayload.Release()
+       dataPayload, err := GetRecordBatchPayload(rec, WithAllocator(mem))
+       require.NoError(t, err)
+       defer dataPayload.Release()
+
+       var schemaBuf, dataBuf bytes.Buffer
+       schemaPayload.SerializeBody(&schemaBuf)
+       dataPayload.SerializeBody(&dataBuf)
+
+       msgrdr := &testMsgReader{
+               messages: []*Message{
+                       NewMessage(schemaPayload.meta, 
memory.NewBufferBytes(schemaBuf.Bytes())),
+                       NewMessage(dataPayload.meta, 
memory.NewBufferBytes(dataBuf.Bytes())),
+               },
+       }
+
+       rdr, err := NewReaderFromMessageReader(msgrdr, WithAllocator(mem))
+       require.NoError(t, err)
+       defer rdr.Release()
+
+       assert.Truef(t, rdr.Schema().Equal(rec.Schema()), "expected: %s\ngot: 
%s", rec.Schema(), rdr.Schema())
+       got, err := rdr.Read()
+       require.NoError(t, err)
+
+       assert.Truef(t, array.RecordEqual(rec, got), "expected: %s\ngot: %s", 
rec, got)
+}

Reply via email to