This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git
The following commit(s) were added to refs/heads/main by this push:
new 2d0962ed fix(flight): make StreamChunksFromReader ctx aware and cancellation-safe (#615)
2d0962ed is described below
commit 2d0962ed55074f050193b82aaf80f8b5995a2ffc
Author: Arnold Wakim <[email protected]>
AuthorDate: Mon Dec 29 16:44:46 2025 +0100
fix(flight): make StreamChunksFromReader ctx aware and cancellation-safe (#615)
### Rationale for this change
`StreamChunksFromReader` previously did not observe context
cancellation. As a result, if a client disconnected early, the reader
could continue producing data indefinitely, potentially blocking on
channel sends, leaking `RecordBatch` objects, leaking the reader, and
consuming unbounded memory and CPU (this observation triggered this PR).
This fix ensures that data streaming promptly stops when the client
disconnects.
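At its core, the fix checks the context before producing each batch and wraps every channel send in a `select` against `ctx.Done()` (see the `record_batch_reader.go` hunk at the end of this diff). Below is a minimal, standalone sketch of that pattern; the `chunk` type and `streamChunks` function are simplified stand-ins, not the actual Flight API:

```go
package main

import (
	"context"
	"fmt"
)

// chunk is a simplified stand-in for flight.StreamChunk.
type chunk struct {
	data int
	err  error
}

// streamChunks mirrors the pattern this fix adopts: check the context before
// producing each value, and wrap every channel send in a select against
// ctx.Done(), so a canceled context never leaves the producer blocked on an
// abandoned channel.
func streamChunks(ctx context.Context, ch chan<- chunk) {
	defer close(ch)
	for i := 0; ; i++ {
		select {
		case <-ctx.Done():
			return // consumer is gone; stop producing
		default:
		}
		select {
		case ch <- chunk{data: i}:
		case <-ctx.Done():
			return // don't block on a send nobody will receive
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	ch := make(chan chunk) // unbuffered: sends block until consumed
	go streamChunks(ctx, ch)

	for c := range ch {
		fmt.Println("got", c.data)
		if c.data == 2 {
			cancel() // simulate the client disconnecting
		}
	}
	// The range loop ends only because streamChunks observed the
	// cancellation, returned, and closed the channel.
}
```

With an unbuffered channel, the old code could block forever on the send once the consumer stopped receiving; the `select` gives the producer a second way out.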
### What changes are included in this PR?
- `StreamChunksFromReader` now accepts a `context.Context`.
- A small change was made to `DoGet` so that it continues to work with the context-aware `StreamChunksFromReader`: it now drains the chunk channel in a `select` that also watches the stream's context (see the sketch below).
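A standalone sketch of that consumer-side loop; `chunk` and `drain` are illustrative stand-ins for `flight.StreamChunk` and the real `DoGet` body:

```go
package main

import (
	"context"
	"fmt"
)

// chunk is again a simplified stand-in for flight.StreamChunk.
type chunk struct {
	data int
	err  error
}

// drain mirrors the updated DoGet loop: instead of ranging over the channel
// unconditionally, it selects between the stream context and the channel, so
// it returns as soon as the client disconnects even if the producer is still
// sending.
func drain(ctx context.Context, cc <-chan chunk) error {
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case c, ok := <-cc:
			if !ok {
				return nil // producer finished and closed the channel
			}
			if c.err != nil {
				return c.err
			}
			fmt.Println("wrote", c.data) // stand-in for writing the record to the stream
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	cc := make(chan chunk)
	go func() {
		defer close(cc)
		for i := 0; i < 3; i++ {
			cc <- chunk{data: i}
		}
	}()

	if err := drain(ctx, cc); err != nil {
		fmt.Println("stopped early:", err)
	}
}
```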
### Are these changes tested?
- To be removed from description: the tests are a bit tricky to write, similar to those of #437. Maybe @zeroshade has suggestions?
### Are there any user-facing changes?
- `StreamChunksFromReader` now accepts a `context.Context`.
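Migration for existing callers is a one-line change at each call site, mirroring the `sqlite_server.go` updates in this diff. A sketch follows; the `doGetExample` handler is hypothetical, and only the `StreamChunksFromReader` call reflects the real API change:

```go
package example

import (
	"context"

	"github.com/apache/arrow-go/v18/arrow"
	"github.com/apache/arrow-go/v18/arrow/array"
	"github.com/apache/arrow-go/v18/arrow/flight"
)

// doGetExample is a hypothetical DoGet-style producer. The only migration
// needed for this change is threading the request context through.
func doGetExample(ctx context.Context, rdr array.RecordReader) (*arrow.Schema, <-chan flight.StreamChunk, error) {
	ch := make(chan flight.StreamChunk)
	// Before: go flight.StreamChunksFromReader(rdr, ch)
	go flight.StreamChunksFromReader(ctx, rdr, ch)
	return rdr.Schema(), ch, nil
}
```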
---------
Co-authored-by: awakim <[email protected]>
---
arrow/flight/flight_test.go | 250 ++++++++++++++++++++++++
arrow/flight/flightsql/example/sqlite_server.go | 4 +-
arrow/flight/flightsql/server.go | 31 +--
arrow/flight/record_batch_reader.go | 23 ++-
4 files changed, 289 insertions(+), 19 deletions(-)
diff --git a/arrow/flight/flight_test.go b/arrow/flight/flight_test.go
index 98d1734c..8d75aac2 100644
--- a/arrow/flight/flight_test.go
+++ b/arrow/flight/flight_test.go
@@ -21,6 +21,8 @@ import (
"errors"
"fmt"
"io"
+ "sync"
+ "sync/atomic"
"testing"
"github.com/apache/arrow-go/v18/arrow"
@@ -484,3 +486,251 @@ type flightStreamWriter struct{}
func (f *flightStreamWriter) Send(data *flight.FlightData) error { return nil }
var _ flight.DataStreamWriter = (*flightStreamWriter)(nil)
+
+// callbackRecordReader is a test record reader that invokes a callback on each Next() call.
+// It tracks whether batches are properly released and whether the reader itself is released.
+type callbackRecordReader struct {
+ mem memory.Allocator
+ schema *arrow.Schema
+ numBatches int
+ currentBatch atomic.Int32
+ onNext func(batchIndex int) // callback invoked before returning from Next()
+ released atomic.Bool
+ batchesCreated atomic.Int32
+ totalRetains atomic.Int32
+ totalReleases atomic.Int32
+ createdBatches []arrow.RecordBatch // track all created batches for cleanup
+ mu sync.Mutex
+}
+
+func newCallbackRecordReader(mem memory.Allocator, schema *arrow.Schema, numBatches int, onNext func(int)) *callbackRecordReader {
+ return &callbackRecordReader{
+ mem: mem,
+ schema: schema,
+ numBatches: numBatches,
+ onNext: onNext,
+ }
+}
+
+func (r *callbackRecordReader) Schema() *arrow.Schema {
+ return r.schema
+}
+
+func (r *callbackRecordReader) Next() bool {
+ current := r.currentBatch.Load()
+ if int(current) >= r.numBatches {
+ return false
+ }
+ r.currentBatch.Add(1)
+
+ if r.onNext != nil {
+ r.onNext(int(current))
+ }
+
+ return true
+}
+
+func (r *callbackRecordReader) RecordBatch() arrow.RecordBatch {
+ bldr := array.NewInt64Builder(r.mem)
+ defer bldr.Release()
+
+ currentBatch := r.currentBatch.Load()
+ bldr.AppendValues([]int64{int64(currentBatch)}, nil)
+ arr := bldr.NewArray()
+
+ rec := array.NewRecordBatch(r.schema, []arrow.Array{arr}, 1)
+ arr.Release()
+
+ tracked := &trackedRecordBatch{
+ RecordBatch: rec,
+ onRetain: func() {
+ r.totalRetains.Add(1)
+ },
+ onRelease: func() {
+ r.totalReleases.Add(1)
+ },
+ }
+
+ r.mu.Lock()
+ r.createdBatches = append(r.createdBatches, tracked)
+ r.mu.Unlock()
+
+ r.batchesCreated.Add(1)
+ return tracked
+}
+
+func (r *callbackRecordReader) ReleaseAll() {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ for _, batch := range r.createdBatches {
+ batch.Release()
+ }
+ r.createdBatches = nil
+}
+
+func (r *callbackRecordReader) Retain() {}
+
+func (r *callbackRecordReader) Release() {
+ r.released.Store(true)
+}
+
+func (r *callbackRecordReader) Record() arrow.RecordBatch {
+ return r.RecordBatch()
+}
+
+func (r *callbackRecordReader) Err() error {
+ return nil
+}
+
+// trackedRecordBatch wraps a RecordBatch to track Retain/Release calls.
+type trackedRecordBatch struct {
+ arrow.RecordBatch
+ onRetain func()
+ onRelease func()
+}
+
+func (t *trackedRecordBatch) Retain() {
+ if t.onRetain != nil {
+ t.onRetain()
+ }
+ t.RecordBatch.Retain()
+}
+
+func (t *trackedRecordBatch) Release() {
+ if t.onRelease != nil {
+ t.onRelease()
+ }
+ t.RecordBatch.Release()
+}
+
+func TestStreamChunksFromReader_OK(t *testing.T) {
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer mem.AssertSize(t, 0)
+
+ schema := arrow.NewSchema([]arrow.Field{{Name: "value", Type: arrow.PrimitiveTypes.Int64}}, nil)
+
+ rdr := newCallbackRecordReader(mem, schema, 5, nil)
+ defer rdr.ReleaseAll()
+
+ ch := make(chan flight.StreamChunk, 5)
+
+ ctx := context.Background()
+
+ go flight.StreamChunksFromReader(ctx, rdr, ch)
+
+ var chunksReceived int
+ for chunk := range ch {
+ if chunk.Err != nil {
+ t.Errorf("unexpected error chunk: %v", chunk.Err)
+ continue
+ }
+ if chunk.Data != nil {
+ chunksReceived++
+ chunk.Data.Release()
+ }
+ }
+
+ require.Equal(t, 5, chunksReceived, "should receive all 5 batches")
+ require.True(t, rdr.released.Load(), "reader should be released")
+
+}
+
+// TestStreamChunksFromReader_HandlesCancellation verifies that context cancellation
+// causes StreamChunksFromReader to exit cleanly and release the reader.
+func TestStreamChunksFromReader_HandlesCancellation(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ mem := memory.DefaultAllocator
+ schema := arrow.NewSchema([]arrow.Field{{Name: "value", Type: arrow.PrimitiveTypes.Int64}}, nil)
+
+ rdr := newCallbackRecordReader(mem, schema, 10, nil)
+ defer rdr.ReleaseAll()
+ ch := make(chan flight.StreamChunk) // unbuffered channel
+
+ go flight.StreamChunksFromReader(ctx, rdr, ch)
+
+ chunksReceived := 0
+ for chunk := range ch {
+ if chunk.Data != nil {
+ chunksReceived++
+ chunk.Data.Release()
+ }
+
+ // Cancel context after 2 batches (simulating server detecting client disconnect)
+ if chunksReceived == 2 {
+ cancel()
+ }
+ }
+
+ // After canceling the context, StreamChunksFromReader exits and closes the channel.
+ // The for-range loop above exits when the channel closes.
+ // By the time we reach here, the channel is closed, which means StreamChunksFromReader's
+ // defer stack has already executed, so the reader must be released.
+
+ require.True(t, rdr.released.Load(), "reader must be released when context is canceled")
+
+}
+
+// TestStreamChunksFromReader_CancellationReleasesBatches verifies that batches are
+// properly tracked and demonstrates memory leaks without cleanup, then proves cleanup fixes it.
+func TestStreamChunksFromReader_CancellationReleasesBatches(t *testing.T) {
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+
+ schema := arrow.NewSchema([]arrow.Field{{Name: "value", Type: arrow.PrimitiveTypes.Int64}}, nil)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ // Create reader that will produce 10 batches, but we'll cancel after 3
+ reader := newCallbackRecordReader(mem, schema, 10, func(batchIndex int) {
+ if batchIndex == 2 {
+ cancel()
+ }
+ })
+
+ ch := make(chan flight.StreamChunk, 5)
+
+ // Start streaming
+ go flight.StreamChunksFromReader(ctx, reader, ch)
+
+ // Consume chunks until channel closes
+ var chunksReceived int
+ for chunk := range ch {
+ if chunk.Err != nil {
+ t.Errorf("unexpected error chunk: %v", chunk.Err)
+ continue
+ }
+ if chunk.Data != nil {
+ chunksReceived++
+ chunk.Data.Release()
+ }
+ }
+
+ // Verify the reader was released
+ require.True(t, reader.released.Load(), "reader should be released")
+
+ // We should have received at most 3-4 chunks (depending on timing)
+ // The important part is we didn't receive all 10
+ require.LessOrEqual(t, chunksReceived, 4, "should not receive all 10 chunks, got %d", chunksReceived)
+ require.Greater(t, chunksReceived, 0, "should receive at least 1 chunk")
+
+ // Check that Retain and Release don't balance - proving there's a leak without manual cleanup
+ retains := reader.totalRetains.Load()
+ releases := reader.totalReleases.Load()
+ batchesCreated := reader.batchesCreated.Load()
+
+ // Each batch starts with refcount=1, then StreamChunksFromReader calls Retain() (refcount=2)
+ // For sent batches: we call Release() (refcount=1), batch still has initial ref
+ // For unsent batches due to cancellation: they keep refcount=1 from creation
+ // So we expect: releases < retains + batchesCreated
+ require.Less(t, releases, retains+batchesCreated,
+ "without cleanup, releases should be less than retains+created:
retains=%d, releases=%d, created=%d",
+ retains, releases, batchesCreated)
+
+ // Now manually release all created batches to show proper cleanup fixes the leak
+ reader.ReleaseAll()
+
+ // After cleanup, memory should be freed
+ mem.AssertSize(t, 0)
+}
diff --git a/arrow/flight/flightsql/example/sqlite_server.go b/arrow/flight/flightsql/example/sqlite_server.go
index fc7d76a2..dca7b2d6 100644
--- a/arrow/flight/flightsql/example/sqlite_server.go
+++ b/arrow/flight/flightsql/example/sqlite_server.go
@@ -354,7 +354,7 @@ func (s *SQLiteFlightSQLServer) DoGetTables(ctx context.Context, cmd flightsql.G
}
schema := rdr.Schema()
- go flight.StreamChunksFromReader(rdr, ch)
+ go flight.StreamChunksFromReader(ctx, rdr, ch)
return schema, ch, nil
}
@@ -485,7 +485,7 @@ func doGetQuery(ctx context.Context, mem memory.Allocator, db dbQueryCtx, query
}
ch := make(chan flight.StreamChunk)
- go flight.StreamChunksFromReader(rdr, ch)
+ go flight.StreamChunksFromReader(ctx, rdr, ch)
return schema, ch, nil
}
diff --git a/arrow/flight/flightsql/server.go b/arrow/flight/flightsql/server.go
index d5102a27..25c89bf5 100644
--- a/arrow/flight/flightsql/server.go
+++ b/arrow/flight/flightsql/server.go
@@ -381,7 +381,7 @@ func (b *BaseServer) GetFlightInfoSqlInfo(_ context.Context, _ GetSqlInfo, desc
}
// DoGetSqlInfo returns a flight stream containing the list of sqlinfo results
-func (b *BaseServer) DoGetSqlInfo(_ context.Context, cmd GetSqlInfo) (*arrow.Schema, <-chan flight.StreamChunk, error) {
+func (b *BaseServer) DoGetSqlInfo(ctx context.Context, cmd GetSqlInfo) (*arrow.Schema, <-chan flight.StreamChunk, error) {
if b.Alloc == nil {
b.Alloc = memory.DefaultAllocator
}
@@ -430,7 +430,7 @@ func (b *BaseServer) DoGetSqlInfo(_ context.Context, cmd GetSqlInfo) (*arrow.Sch
}
// StreamChunksFromReader will call release on the reader when done
- go flight.StreamChunksFromReader(rdr, ch)
+ go flight.StreamChunksFromReader(ctx, rdr, ch)
return schema_ref.SqlInfo, ch, nil
}
@@ -927,19 +927,24 @@ func (f *flightSqlServer) DoGet(request *flight.Ticket, stream flight.FlightServ
wr := flight.NewRecordWriter(stream, ipc.WithSchema(sc))
defer wr.Close()
- for chunk := range cc {
- if chunk.Err != nil {
- return chunk.Err
- }
-
- wr.SetFlightDescriptor(chunk.Desc)
- if err = wr.WriteWithAppMetadata(chunk.Data, chunk.AppMetadata); err != nil {
- return err
+ for {
+ select {
+ case <-stream.Context().Done():
+ return stream.Context().Err()
+ case chunk, ok := <-cc:
+ if !ok {
+ return nil
+ }
+ if chunk.Err != nil {
+ return chunk.Err
+ }
+ wr.SetFlightDescriptor(chunk.Desc)
+ if err := wr.WriteWithAppMetadata(chunk.Data, chunk.AppMetadata); err != nil {
+ return err
+ }
+ chunk.Data.Release()
}
- chunk.Data.Release()
}
-
- return err
}
type putMetadataWriter struct {
diff --git a/arrow/flight/record_batch_reader.go b/arrow/flight/record_batch_reader.go
index 7b744075..e6990a57 100644
--- a/arrow/flight/record_batch_reader.go
+++ b/arrow/flight/record_batch_reader.go
@@ -18,6 +18,7 @@ package flight
import (
"bytes"
+ "context"
"errors"
"fmt"
"io"
@@ -212,24 +213,38 @@ type haserr interface {
// StreamChunksFromReader is a convenience function to populate a channel
// from a record reader. It is intended to be run using a separate goroutine
-// by calling `go flight.StreamChunksFromReader(rdr, ch)`.
+// by calling `go flight.StreamChunksFromReader(ctx, rdr, ch)`.
//
// If the record reader panics, an error chunk will get sent on the channel.
//
// This will close the channel and release the reader when it completes.
-func StreamChunksFromReader(rdr array.RecordReader, ch chan<- StreamChunk) {
+func StreamChunksFromReader(ctx context.Context, rdr array.RecordReader, ch chan<- StreamChunk) {
defer close(ch)
defer func() {
if err := recover(); err != nil {
- ch <- StreamChunk{Err: utils.FormatRecoveredError("panic while reading", err)}
+ select {
+ case ch <- StreamChunk{Err: utils.FormatRecoveredError("panic while reading", err)}:
+ case <-ctx.Done():
+ }
}
}()
defer rdr.Release()
for rdr.Next() {
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ }
+
rec := rdr.RecordBatch()
rec.Retain()
- ch <- StreamChunk{Data: rec}
+ select {
+ case ch <- StreamChunk{Data: rec}:
+ case <-ctx.Done():
+ rec.Release()
+ return
+ }
}
if e, ok := rdr.(haserr); ok {