This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new f5b3fc7065 GH-48076: [C++][Flight] fix GeneratorStream for Tables (#48082)
f5b3fc7065 is described below
commit f5b3fc7065f930aa7cc5a51cb963544e410d5518
Author: Dan Homola <[email protected]>
AuthorDate: Fri Nov 14 07:27:19 2025 +0100
GH-48076: [C++][Flight] fix GeneratorStream for Tables (#48082)
### Rationale for this change
After the changes in #47115, GeneratorStreams backed by anything other than
RecordBatches failed. This includes Tables and RecordBatchReaders.
This was caused by a too strict assumption that the
RecordBatchStream#GetSchemaPayload would always get called, which is not the
case when the GeneratorStream is backed by a Table or a RecordBatchReader.
### What changes are included in this PR?
Removal of the problematic assertion and initialization of the writer
object when it is needed first.
Also, to accommodate this case, drop the incoming message when
initializing the writer in Next, as the message there
is of the SCHEMA type and we want a RECORD_BATCH or DICTIONARY_BATCH one.
### Are these changes tested?
Yes, via CI. Tests for the GeneratorStreams were extended so that they test
GeneratorStreams backed by Tables and RecordBatchReaders, not just
RecordBatches.
### Are there any user-facing changes?
No, just a fix for a regression restoring the functionality from version
21.0.0 and earlier.
* GitHub Issue: #48076
Authored-by: Dan Homola <[email protected]>
Signed-off-by: David Li <[email protected]>
---
cpp/src/arrow/flight/server.cc | 27 +++++---
python/pyarrow/tests/test_flight.py | 131 +++++++++++++++++++++++++++++++++++-
2 files changed, 149 insertions(+), 9 deletions(-)
diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc
index fa7f99e012..7807829b3b 100644
--- a/cpp/src/arrow/flight/server.cc
+++ b/cpp/src/arrow/flight/server.cc
@@ -285,12 +285,7 @@ class RecordBatchStream::RecordBatchStreamImpl {
Status GetSchemaPayload(FlightPayload* payload) {
if (!writer_) {
- // Create the IPC writer on first call
- auto payload_writer =
-     std::make_unique<ServerRecordBatchPayloadWriter>(&payload_deque_);
- ARROW_ASSIGN_OR_RAISE(
-     writer_, ipc::internal::OpenRecordBatchWriter(std::move(payload_writer),
-                                                   reader_->schema(), options_));
+ RETURN_NOT_OK(InitializeWriter());
}
// Return the expected schema payload.
@@ -317,8 +312,15 @@ class RecordBatchStream::RecordBatchStreamImpl {
return Status::OK();
}
if (!writer_) {
- return Status::UnknownError(
-     "Writer should be initialized before reading Next batches");
+ RETURN_NOT_OK(InitializeWriter());
+ // If the writer has not been initialized yet, the first batch in the payload
+ // queue is going to be a SCHEMA one. In this context, that is
+ // unexpected, so drop it from the queue so that there is a RECORD_BATCH
+ // message on the top (same as would be if the writer had been initialized
+ // in GetSchemaPayload).
+ if (payload_deque_.front().ipc_message.type == ipc::MessageType::SCHEMA) {
+   payload_deque_.pop_front();
+ }
}
// One WriteRecordBatch call might generate multiple payloads, so we
// need to collect them in a deque.
@@ -370,6 +372,15 @@ class RecordBatchStream::RecordBatchStreamImpl {
ipc::IpcWriteOptions options_;
std::unique_ptr<ipc::RecordBatchWriter> writer_;
std::deque<FlightPayload> payload_deque_;
+
+ Status InitializeWriter() {
+   auto payload_writer =
+       std::make_unique<ServerRecordBatchPayloadWriter>(&payload_deque_);
+   ARROW_ASSIGN_OR_RAISE(
+       writer_, ipc::internal::OpenRecordBatchWriter(std::move(payload_writer),
+                                                     reader_->schema(), options_));
+   return Status::OK();
+ }
};
FlightMetadataWriter::~FlightMetadataWriter() = default;
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index e5edc0eaa2..9e7bb31239 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -246,6 +246,40 @@ class EchoStreamFlightServer(EchoFlightServer):
raise NotImplementedError
+class EchoTableStreamFlightServer(EchoFlightServer):
+ """An echo server that streams the whole table."""
+
+ def do_get(self, context, ticket):
+ return flight.GeneratorStream(
+ self.last_message.schema,
+ [self.last_message])
+
+ def list_actions(self, context):
+ return []
+
+ def do_action(self, context, action):
+ if action.type == "who-am-i":
+ return [context.peer_identity(), context.peer().encode("utf-8")]
+ raise NotImplementedError
+
+
+class EchoRecordBatchReaderStreamFlightServer(EchoFlightServer):
+ """An echo server that streams the whole table as a RecordBatchReader."""
+
+ def do_get(self, context, ticket):
+ return flight.GeneratorStream(
+ self.last_message.schema,
+ [self.last_message.to_reader()])
+
+ def list_actions(self, context):
+ return []
+
+ def do_action(self, context, action):
+ if action.type == "who-am-i":
+ return [context.peer_identity(), context.peer().encode("utf-8")]
+ raise NotImplementedError
+
+
class GetInfoFlightServer(FlightServerBase):
"""A Flight server that tests GetFlightInfo."""
@@ -1362,7 +1396,7 @@ def test_flight_large_message():
assert result.equals(data)
-def test_flight_generator_stream():
+def test_flight_generator_stream_of_batches():
"""Try downloading a flight of RecordBatches in a GeneratorStream."""
data = pa.Table.from_arrays([
pa.array(range(0, 10 * 1024))
@@ -1378,6 +1412,101 @@ def test_flight_generator_stream():
assert result.equals(data)
+def test_flight_generator_stream_of_batches_with_dict():
+ """
+ Try downloading a flight of RecordBatches with dictionaries
+ in a GeneratorStream.
+ """
+ data = pa.Table.from_arrays([
+ pa.array(["foo", "bar", "baz", "foo", "foo"],
+ pa.dictionary(pa.int64(), pa.utf8())),
+ pa.array([123, 234, 345, 456, 567])
+ ], names=['a', 'b'])
+
+ with EchoRecordBatchReaderStreamFlightServer() as server, \
+ FlightClient(('localhost', server.port)) as client:
+ writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
+ data.schema)
+ writer.write_table(data)
+ writer.close()
+ result = client.do_get(flight.Ticket(b'')).read_all()
+ assert result.equals(data)
+
+
+def test_flight_generator_stream_of_table():
+ """Try downloading a flight of Table in a GeneratorStream."""
+ data = pa.Table.from_arrays([
+ pa.array(range(0, 10 * 1024))
+ ], names=['a'])
+
+ with EchoTableStreamFlightServer() as server, \
+ FlightClient(('localhost', server.port)) as client:
+ writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
+ data.schema)
+ writer.write_table(data)
+ writer.close()
+ result = client.do_get(flight.Ticket(b'')).read_all()
+ assert result.equals(data)
+
+
+def test_flight_generator_stream_of_table_with_dict():
+ """
+ Try downloading a flight of Table with dictionaries
+ in a GeneratorStream.
+ """
+ data = pa.Table.from_arrays([
+ pa.array(["foo", "bar", "baz", "foo", "foo"],
+ pa.dictionary(pa.int64(), pa.utf8())),
+ pa.array([123, 234, 345, 456, 567])
+ ], names=['a', 'b'])
+
+ with EchoRecordBatchReaderStreamFlightServer() as server, \
+ FlightClient(('localhost', server.port)) as client:
+ writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
+ data.schema)
+ writer.write_table(data)
+ writer.close()
+ result = client.do_get(flight.Ticket(b'')).read_all()
+ assert result.equals(data)
+
+
+def test_flight_generator_stream_of_record_batch_reader():
+ """Try downloading a flight of RecordBatchReader in a GeneratorStream."""
+ data = pa.Table.from_arrays([
+ pa.array(range(0, 10 * 1024))
+ ], names=['a'])
+
+ with EchoRecordBatchReaderStreamFlightServer() as server, \
+ FlightClient(('localhost', server.port)) as client:
+ writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
+ data.schema)
+ writer.write_table(data)
+ writer.close()
+ result = client.do_get(flight.Ticket(b'')).read_all()
+ assert result.equals(data)
+
+
+def test_flight_generator_stream_of_record_batch_reader_with_dict():
+ """
+ Try downloading a flight of RecordBatchReader with dictionaries
+ in a GeneratorStream.
+ """
+ data = pa.Table.from_arrays([
+ pa.array(["foo", "bar", "baz", "foo", "foo"],
+ pa.dictionary(pa.int64(), pa.utf8())),
+ pa.array([123, 234, 345, 456, 567])
+ ], names=['a', 'b'])
+
+ with EchoRecordBatchReaderStreamFlightServer() as server, \
+ FlightClient(('localhost', server.port)) as client:
+ writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
+ data.schema)
+ writer.write_table(data)
+ writer.close()
+ result = client.do_get(flight.Ticket(b'')).read_all()
+ assert result.equals(data)
+
+
def test_flight_invalid_generator_stream():
"""Try streaming data with mismatched schemas."""
with InvalidStreamFlightServer() as server, \