This is an automated email from the ASF dual-hosted git repository.
wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new d90c159 ARROW-4858: [Flight/Python] enable FlightDataStream to be implemented in Python
d90c159 is described below
commit d90c1597ce488d27aca88bcc33bfffe1e96a1ac6
Author: David Li <[email protected]>
AuthorDate: Fri Mar 15 09:59:10 2019 -0500
ARROW-4858: [Flight/Python] enable FlightDataStream to be implemented in Python
This enables use cases like converting and streaming a large Pandas
DataFrame in chunks, or generating/processing data on the fly.
The implementation could be simplified if we decide to support only iterators
of RecordBatches; right now the generator may also yield Tables,
RecordBatchReaders, and other FlightDataStreams.
Author: David Li <[email protected]>
Closes #3890 from lihalite/arrow-py-data and squashes the following commits:
ed108cfd5 <David Li> Document that Flight/Python helpers require GIL to construct
0bcd8e144 <David Li> Check that PyGeneratorFlightDataStream returns consistent schema
27298945e <David Li> Enable FlightDataStream to be implemented in Python
---
cpp/src/arrow/python/flight.cc | 16 +++++
cpp/src/arrow/python/flight.h | 26 +++++++++
python/pyarrow/_flight.pyx | 90 +++++++++++++++++++++++++++++
python/pyarrow/flight.py | 1 +
python/pyarrow/includes/libarrow_flight.pxd | 29 +++++++++-
python/pyarrow/tests/test_flight.py | 47 +++++++++++++++
6 files changed, 208 insertions(+), 1 deletion(-)
diff --git a/cpp/src/arrow/python/flight.cc b/cpp/src/arrow/python/flight.cc
index ec25d32..326245b 100644
--- a/cpp/src/arrow/python/flight.cc
+++ b/cpp/src/arrow/python/flight.cc
@@ -130,6 +130,22 @@ Status
PyFlightDataStream::Next(arrow::flight::FlightPayload* payload) {
return stream_->Next(payload);
}
+PyGeneratorFlightDataStream::PyGeneratorFlightDataStream(
+ PyObject* generator, std::shared_ptr<arrow::Schema> schema,
+ PyGeneratorFlightDataStreamCallback callback)
+ : schema_(schema), callback_(callback) {
+ Py_INCREF(generator);
+ generator_.reset(generator);
+}
+
+std::shared_ptr<arrow::Schema> PyGeneratorFlightDataStream::schema() { return
schema_; }
+
+Status PyGeneratorFlightDataStream::Next(arrow::flight::FlightPayload*
payload) {
+ PyAcquireGIL lock;
+ callback_(generator_.obj(), payload);
+ return CheckPyError();
+}
+
Status CreateFlightInfo(const std::shared_ptr<arrow::Schema>& schema,
const arrow::flight::FlightDescriptor& descriptor,
const std::vector<arrow::flight::FlightEndpoint>&
endpoints,
diff --git a/cpp/src/arrow/python/flight.h b/cpp/src/arrow/python/flight.h
index 128784f..8a655d2 100644
--- a/cpp/src/arrow/python/flight.h
+++ b/cpp/src/arrow/python/flight.h
@@ -83,6 +83,8 @@ typedef std::function<void(PyObject*,
std::unique_ptr<arrow::flight::Result>*)>
/// \brief A ResultStream built around a Python callback.
class ARROW_PYTHON_EXPORT PyFlightResultStream : public
arrow::flight::ResultStream {
public:
+ /// \brief Construct a FlightResultStream from a Python object and callback.
+ /// Must only be called while holding the GIL.
explicit PyFlightResultStream(PyObject* generator,
PyFlightResultStreamCallback callback);
Status Next(std::unique_ptr<arrow::flight::Result>* result) override;
@@ -96,6 +98,8 @@ class ARROW_PYTHON_EXPORT PyFlightResultStream : public
arrow::flight::ResultStr
/// Python object backing it.
class ARROW_PYTHON_EXPORT PyFlightDataStream : public
arrow::flight::FlightDataStream {
public:
+ /// \brief Construct a FlightDataStream from a Python object and underlying stream.
+ /// Must only be called while holding the GIL.
explicit PyFlightDataStream(PyObject* data_source,
std::unique_ptr<arrow::flight::FlightDataStream>
stream);
std::shared_ptr<arrow::Schema> schema() override;
@@ -106,6 +110,28 @@ class ARROW_PYTHON_EXPORT PyFlightDataStream : public
arrow::flight::FlightDataS
std::unique_ptr<arrow::flight::FlightDataStream> stream_;
};
+/// \brief A callback that obtains the next payload from a Flight data stream.
+typedef std::function<void(PyObject*, arrow::flight::FlightPayload*)>
+ PyGeneratorFlightDataStreamCallback;
+
+/// \brief A FlightDataStream built around a Python callback.
+class ARROW_PYTHON_EXPORT PyGeneratorFlightDataStream
+ : public arrow::flight::FlightDataStream {
+ public:
+ /// \brief Construct a FlightDataStream from a Python generator, schema, and callback.
+ /// Must only be called while holding the GIL.
+ explicit PyGeneratorFlightDataStream(PyObject* generator,
+ std::shared_ptr<arrow::Schema> schema,
+ PyGeneratorFlightDataStreamCallback
callback);
+ std::shared_ptr<arrow::Schema> schema() override;
+ Status Next(arrow::flight::FlightPayload* payload) override;
+
+ private:
+ OwnedRefNoGIL generator_;
+ std::shared_ptr<arrow::Schema> schema_;
+ PyGeneratorFlightDataStreamCallback callback_;
+};
+
ARROW_PYTHON_EXPORT
Status CreateFlightInfo(const std::shared_ptr<arrow::Schema>& schema,
const arrow::flight::FlightDescriptor& descriptor,
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index 695513a..c79c738 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -445,6 +445,96 @@ cdef class RecordBatchStream(FlightDataStream):
return new CRecordBatchStream(reader)
+cdef class GeneratorStream(FlightDataStream):
+ """A Flight data stream backed by a Python generator."""
+ cdef:
+ shared_ptr[CSchema] schema
+ object generator
+ # A substream currently being consumed by the client, if
+ # present. Produced by the generator.
+ unique_ptr[CFlightDataStream] current_stream
+
+ def __init__(self, schema, generator):
+ """Create a GeneratorStream from a Python generator.
+
+ Parameters
+ ----------
+ schema : Schema
+ The schema for the data to be returned.
+
+ generator : iterator or iterable
+ The generator should yield other FlightDataStream objects,
+ Tables, RecordBatches, or RecordBatchReaders.
+ """
+ self.schema = pyarrow_unwrap_schema(schema)
+ self.generator = iter(generator)
+
+ cdef CFlightDataStream* to_stream(self):
+ cdef:
+ function[cb_data_stream_next] callback = &_data_stream_next
+ return new CPyGeneratorFlightDataStream(self, self.schema, callback)
+
+
+cdef void _data_stream_next(void* self, CFlightPayload* payload) except *:
+ """Callback for implementing FlightDataStream in Python."""
+ cdef:
+ unique_ptr[CFlightDataStream] data_stream
+
+ py_stream = <object> self
+ if not isinstance(py_stream, GeneratorStream):
+ raise RuntimeError("self object in callback is not GeneratorStream")
+ stream = <GeneratorStream> py_stream
+
+ if stream.current_stream != nullptr:
+ check_status(stream.current_stream.get().Next(payload))
+ # If the stream ended, see if there's another stream from the
+ # generator
+ if payload.ipc_message.metadata != nullptr:
+ return
+ stream.current_stream.reset(nullptr)
+
+ try:
+ result = next(stream.generator)
+ except StopIteration:
+ payload.ipc_message.metadata.reset(<CBuffer*> nullptr)
+ return
+
+ if isinstance(result, (Table, _CRecordBatchReader)):
+ result = RecordBatchStream(result)
+
+ stream_schema = pyarrow_wrap_schema(stream.schema)
+ if isinstance(result, FlightDataStream):
+ data_stream = unique_ptr[CFlightDataStream](
+ (<FlightDataStream> result).to_stream())
+ substream_schema = pyarrow_wrap_schema(data_stream.get().schema())
+ if substream_schema != stream_schema:
+ raise ValueError("Got a FlightDataStream whose schema does not "
+ "match the declared schema of this "
+ "GeneratorStream. "
+ "Got: {}\nExpected: {}".format(substream_schema,
+ stream_schema))
+ stream.current_stream.reset(
+ new CPyFlightDataStream(result, move(data_stream)))
+ _data_stream_next(self, payload)
+ elif isinstance(result, RecordBatch):
+ batch = <RecordBatch> result
+ if batch.schema != stream_schema:
+ raise ValueError("Got a RecordBatch whose schema does not "
+ "match the declared schema of this "
+ "GeneratorStream. "
+ "Got: {}\nExpected: {}".format(batch.schema,
+ stream_schema))
+ check_status(_GetRecordBatchPayload(
+ deref(batch.batch),
+ c_default_memory_pool(),
+ &payload.ipc_message))
+ else:
+ raise TypeError("GeneratorStream must be initialized with "
+ "an iterator of FlightDataStream, Table, "
+ "RecordBatch, or RecordBatchStreamReader objects, "
+ "not {}.".format(type(result)))
+
+
cdef void _list_flights(void* self, const CCriteria* c_criteria,
unique_ptr[CFlightListing]* listing) except *:
"""Callback for implementing ListFlights in Python."""
diff --git a/python/pyarrow/flight.py b/python/pyarrow/flight.py
index ae80a8b..f797419 100644
--- a/python/pyarrow/flight.py
+++ b/python/pyarrow/flight.py
@@ -23,6 +23,7 @@ from pyarrow._flight import (Action, # noqa
FlightEndpoint,
FlightInfo,
FlightServerBase,
+ GeneratorStream,
Location,
Ticket,
RecordBatchStream,
diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd
index 153f725..faf58cb 100644
--- a/python/pyarrow/includes/libarrow_flight.pxd
+++ b/python/pyarrow/includes/libarrow_flight.pxd
@@ -23,6 +23,20 @@ from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
+cdef extern from "arrow/ipc/api.h" namespace "arrow" nogil:
+ cdef cppclass CIpcPayload" arrow::ipc::internal::IpcPayload":
+ MessageType type
+ shared_ptr[CBuffer] metadata
+ vector[shared_ptr[CBuffer]] body_buffers
+ int64_t body_length
+
+ cdef CStatus _GetRecordBatchPayload\
+ " arrow::ipc::internal::GetRecordBatchPayload"(
+ const CRecordBatch& batch,
+ CMemoryPool* pool,
+ CIpcPayload* out)
+
+
cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
cdef cppclass CActionType" arrow::flight::ActionType":
c_string type
@@ -94,8 +108,13 @@ cdef extern from "arrow/flight/api.h" namespace "arrow"
nogil:
" arrow::flight::FlightMessageReader"(CRecordBatchReader):
CFlightDescriptor& descriptor()
+ cdef cppclass CFlightPayload" arrow::flight::FlightPayload":
+ shared_ptr[CBuffer] descriptor
+ CIpcPayload ipc_message
+
cdef cppclass CFlightDataStream" arrow::flight::FlightDataStream":
- pass
+ shared_ptr[CSchema] schema()
+ CStatus Next(CFlightPayload*)
cdef cppclass CRecordBatchStream \
" arrow::flight::RecordBatchStream"(CFlightDataStream):
@@ -133,6 +152,7 @@ ctypedef void cb_do_action(object, const CAction&,
unique_ptr[CResultStream]*)
ctypedef void cb_list_actions(object, vector[CActionType]*)
ctypedef void cb_result_next(object, unique_ptr[CResult]*)
+ctypedef void cb_data_stream_next(object, CFlightPayload*)
cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil:
cdef cppclass PyFlightServerVtable:
@@ -161,6 +181,13 @@ cdef extern from "arrow/python/flight.h" namespace
"arrow::py::flight" nogil:
CPyFlightDataStream(object data_source,
unique_ptr[CFlightDataStream] stream)
+ cdef cppclass CPyGeneratorFlightDataStream\
+ " arrow::py::flight::PyGeneratorFlightDataStream"\
+ (CFlightDataStream):
+ CPyGeneratorFlightDataStream(object generator,
+ shared_ptr[CSchema] schema,
+ function[cb_data_stream_next] callback)
+
cdef CStatus CreateFlightInfo" arrow::py::flight::CreateFlightInfo"(
shared_ptr[CSchema] schema,
CFlightDescriptor& descriptor,
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index b1b6a12..73dd018 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -57,6 +57,28 @@ class EchoFlightServer(flight.FlightServerBase):
self.last_message = reader.read_all()
+class EchoStreamFlightServer(EchoFlightServer):
+ """An echo server that streams individual record batches."""
+
+ def do_get(self, ticket):
+ return flight.GeneratorStream(
+ self.last_message.schema,
+ self.last_message.to_batches(chunksize=1024))
+
+
+class InvalidStreamFlightServer(flight.FlightServerBase):
+ """A Flight server that tries to return messages with differing schemas."""
+ data1 = [pa.array([-10, -5, 0, 5, 10])]
+ data2 = [pa.array([-10.0, -5.0, 0.0, 5.0, 10.0])]
+ table1 = pa.Table.from_arrays(data1, names=['a'])
+ table2 = pa.Table.from_arrays(data2, names=['a'])
+
+ schema = table1.schema
+
+ def do_get(self, ticket):
+ return flight.GeneratorStream(self.schema, [self.table1, self.table2])
+
+
@contextlib.contextmanager
def flight_server(server_base, *args, **kwargs):
"""Spawn a Flight server on a free port, shutting it down when done."""
@@ -114,3 +136,28 @@ def test_flight_large_message():
writer.close()
result = client.do_get(flight.Ticket(b''), data.schema).read_all()
assert result.equals(data)
+
+
+def test_flight_generator_stream():
+ """Try downloading a flight of RecordBatches in a GeneratorStream."""
+ data = pa.Table.from_arrays([
+ pa.array(range(0, 10 * 1024))
+ ], names=['a'])
+
+ with flight_server(EchoStreamFlightServer) as server_port:
+ client = flight.FlightClient.connect('localhost', server_port)
+ writer = client.do_put(flight.FlightDescriptor.for_path('test'),
+ data.schema)
+ writer.write_table(data)
+ writer.close()
+ result = client.do_get(flight.Ticket(b''), data.schema).read_all()
+ assert result.equals(data)
+
+
+def test_flight_invalid_generator_stream():
+ """Try streaming data with mismatched schemas."""
+ with flight_server(InvalidStreamFlightServer) as server_port:
+ client = flight.FlightClient.connect('localhost', server_port)
+ schema = InvalidStreamFlightServer.schema
+ with pytest.raises(pa.ArrowException):
+ client.do_get(flight.Ticket(b''), schema).read_all()