This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new d90c159  ARROW-4858: [Flight/Python] enable FlightDataStream to be implemented in Python
d90c159 is described below

commit d90c1597ce488d27aca88bcc33bfffe1e96a1ac6
Author: David Li <[email protected]>
AuthorDate: Fri Mar 15 09:59:10 2019 -0500

    ARROW-4858: [Flight/Python] enable FlightDataStream to be implemented in Python
    
    This enables use cases like converting and streaming a large Pandas 
DataFrame in chunks, or generating/processing data on the fly.
    
    It could be simplified if we decide to support only iterators of
    RecordBatches (currently it accepts Tables, RecordBatches, other
    FlightDataStreams, etc.).
    
    Author: David Li <[email protected]>
    
    Closes #3890 from lihalite/arrow-py-data and squashes the following commits:
    
    ed108cfd5 <David Li> Document that Flight/Python helpers require GIL to construct
    0bcd8e144 <David Li> Check that PyGeneratorFlightDataStream returns consistent schema
    27298945e <David Li> Enable FlightDataStream to be implemented in Python
---
 cpp/src/arrow/python/flight.cc              | 16 +++++
 cpp/src/arrow/python/flight.h               | 26 +++++++++
 python/pyarrow/_flight.pyx                  | 90 +++++++++++++++++++++++++++++
 python/pyarrow/flight.py                    |  1 +
 python/pyarrow/includes/libarrow_flight.pxd | 29 +++++++++-
 python/pyarrow/tests/test_flight.py         | 47 +++++++++++++++
 6 files changed, 208 insertions(+), 1 deletion(-)

diff --git a/cpp/src/arrow/python/flight.cc b/cpp/src/arrow/python/flight.cc
index ec25d32..326245b 100644
--- a/cpp/src/arrow/python/flight.cc
+++ b/cpp/src/arrow/python/flight.cc
@@ -130,6 +130,22 @@ Status PyFlightDataStream::Next(arrow::flight::FlightPayload* payload) {
   return stream_->Next(payload);
 }
 
+// Wraps a Python object and a callback; the stream keeps its own reference
+// to `generator` for as long as it lives. The caller must hold the GIL,
+// because Py_INCREF is called here.
+PyGeneratorFlightDataStream::PyGeneratorFlightDataStream(
+    PyObject* generator, std::shared_ptr<arrow::Schema> schema,
+    PyGeneratorFlightDataStreamCallback callback)
+    : schema_(schema), callback_(callback) {
+  // Take a new reference and hand it to generator_ (an OwnedRefNoGIL).
+  Py_INCREF(generator);
+  generator_.reset(generator);
+}
+
+// Returns the schema supplied at construction.
+std::shared_ptr<arrow::Schema> PyGeneratorFlightDataStream::schema() { return 
schema_; }
+
+// Fills `payload` by invoking the callback on the wrapped Python object.
+// The GIL is acquired here because the caller is not required to hold it.
+Status PyGeneratorFlightDataStream::Next(arrow::flight::FlightPayload* 
payload) {
+  PyAcquireGIL lock;
+  // Any Python exception left pending by the callback is reported through
+  // the Status returned by CheckPyError().
+  callback_(generator_.obj(), payload);
+  return CheckPyError();
+}
+
 Status CreateFlightInfo(const std::shared_ptr<arrow::Schema>& schema,
                         const arrow::flight::FlightDescriptor& descriptor,
                         const std::vector<arrow::flight::FlightEndpoint>& 
endpoints,
diff --git a/cpp/src/arrow/python/flight.h b/cpp/src/arrow/python/flight.h
index 128784f..8a655d2 100644
--- a/cpp/src/arrow/python/flight.h
+++ b/cpp/src/arrow/python/flight.h
@@ -83,6 +83,8 @@ typedef std::function<void(PyObject*, std::unique_ptr<arrow::flight::Result>*)>
 /// \brief A ResultStream built around a Python callback.
 class ARROW_PYTHON_EXPORT PyFlightResultStream : public 
arrow::flight::ResultStream {
  public:
+  /// \brief Construct a FlightResultStream from a Python object and callback.
+  /// Must only be called while holding the GIL.
   explicit PyFlightResultStream(PyObject* generator,
                                 PyFlightResultStreamCallback callback);
   Status Next(std::unique_ptr<arrow::flight::Result>* result) override;
@@ -96,6 +98,8 @@ class ARROW_PYTHON_EXPORT PyFlightResultStream : public arrow::flight::ResultStr
 /// Python object backing it.
 class ARROW_PYTHON_EXPORT PyFlightDataStream : public 
arrow::flight::FlightDataStream {
  public:
+  /// \brief Construct a FlightDataStream from a Python object and underlying 
stream.
+  /// Must only be called while holding the GIL.
   explicit PyFlightDataStream(PyObject* data_source,
                               std::unique_ptr<arrow::flight::FlightDataStream> 
stream);
   std::shared_ptr<arrow::Schema> schema() override;
@@ -106,6 +110,28 @@ class ARROW_PYTHON_EXPORT PyFlightDataStream : public arrow::flight::FlightDataS
   std::unique_ptr<arrow::flight::FlightDataStream> stream_;
 };
 
+/// \brief A callback that fills in the next payload of a FlightDataStream
+/// backed by a Python object.
+typedef std::function<void(PyObject*, arrow::flight::FlightPayload*)>
+    PyGeneratorFlightDataStreamCallback;
+
+/// \brief A FlightDataStream built around a Python callback.
+///
+/// Each call to Next() invokes the callback with the wrapped Python object
+/// and the payload to be filled in.
+class ARROW_PYTHON_EXPORT PyGeneratorFlightDataStream
+    : public arrow::flight::FlightDataStream {
+ public:
+  /// \brief Construct a FlightDataStream from a Python object, the declared
+  /// schema of the stream, and a callback that fills in each payload.
+  /// Must only be called while holding the GIL.
+  explicit PyGeneratorFlightDataStream(PyObject* generator,
+                                       std::shared_ptr<arrow::Schema> schema,
+                                       PyGeneratorFlightDataStreamCallback 
callback);
+  /// \brief Return the schema given at construction.
+  std::shared_ptr<arrow::Schema> schema() override;
+  /// \brief Fill in the next payload by invoking the callback (acquires the GIL).
+  Status Next(arrow::flight::FlightPayload* payload) override;
+
+ private:
+  // Owned reference to the Python object passed to the callback.
+  OwnedRefNoGIL generator_;
+  std::shared_ptr<arrow::Schema> schema_;
+  PyGeneratorFlightDataStreamCallback callback_;
+};
+
 ARROW_PYTHON_EXPORT
 Status CreateFlightInfo(const std::shared_ptr<arrow::Schema>& schema,
                         const arrow::flight::FlightDescriptor& descriptor,
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index 695513a..c79c738 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -445,6 +445,96 @@ cdef class RecordBatchStream(FlightDataStream):
         return new CRecordBatchStream(reader)
 
 
+cdef class GeneratorStream(FlightDataStream):
+    """A Flight data stream backed by a Python generator."""
+    cdef:
+        # Declared schema; every item yielded by `generator` must match it.
+        shared_ptr[CSchema] schema
+        # Iterator obtained from the constructor's `generator` via iter().
+        object generator
+        # A substream currently being consumed by the client, if
+        # present. Produced by the generator.
+        unique_ptr[CFlightDataStream] current_stream
+
+    def __init__(self, schema, generator):
+        """Create a GeneratorStream from a Python generator.
+
+        Items are pulled from the generator one at a time, as the client
+        reads the stream.
+
+        Parameters
+        ----------
+        schema : Schema
+            The schema for the data to be returned. Every yielded item
+            must have exactly this schema; a mismatch raises ValueError
+            when that item is reached while streaming.
+
+        generator : iterator or iterable
+            The generator should yield other FlightDataStream objects,
+            Tables, RecordBatches, or RecordBatchReaders. iter() is
+            called on it immediately.
+        """
+        self.schema = pyarrow_unwrap_schema(schema)
+        self.generator = iter(generator)
+
+    cdef CFlightDataStream* to_stream(self):
+        # The C++ stream wraps `self` (this GeneratorStream, not the bare
+        # generator); _data_stream_next casts it back to reach both the
+        # generator and current_stream.
+        cdef:
+            function[cb_data_stream_next] callback = &_data_stream_next
+        return new CPyGeneratorFlightDataStream(self, self.schema, callback)
+
+
+cdef void _data_stream_next(void* self, CFlightPayload* payload) except *:
+    """Callback for implementing FlightDataStream in Python.
+
+    Fills ``payload`` with the next message of the GeneratorStream passed
+    as ``self``. End of stream is signalled by leaving
+    ``payload.ipc_message.metadata`` null.
+    """
+    cdef:
+        unique_ptr[CFlightDataStream] data_stream
+
+    py_stream = <object> self
+    if not isinstance(py_stream, GeneratorStream):
+        raise RuntimeError("self object in callback is not GeneratorStream")
+    stream = <GeneratorStream> py_stream
+
+    # First drain any substream produced by an earlier generator item.
+    if stream.current_stream != nullptr:
+        check_status(stream.current_stream.get().Next(payload))
+        # Non-null metadata means the substream produced a message. Null
+        # means it is exhausted, so drop it and pull the next item from the
+        # generator.
+        if payload.ipc_message.metadata != nullptr:
+            return
+        stream.current_stream.reset(nullptr)
+
+    try:
+        result = next(stream.generator)
+    except StopIteration:
+        # Generator exhausted: signal end of stream with null metadata.
+        payload.ipc_message.metadata.reset(<CBuffer*> nullptr)
+        return
+
+    # Wrap Tables and readers so they take the FlightDataStream path below.
+    if isinstance(result, (Table, _CRecordBatchReader)):
+        result = RecordBatchStream(result)
+
+    stream_schema = pyarrow_wrap_schema(stream.schema)
+    if isinstance(result, FlightDataStream):
+        data_stream = unique_ptr[CFlightDataStream](
+            (<FlightDataStream> result).to_stream())
+        substream_schema = pyarrow_wrap_schema(data_stream.get().schema())
+        if substream_schema != stream_schema:
+            raise ValueError("Got a FlightDataStream whose schema does not "
+                             "match the declared schema of this "
+                             "GeneratorStream. "
+                             "Got: {}\nExpected: {}".format(substream_schema,
+                                                            stream_schema))
+        # CPyFlightDataStream keeps `result` alive alongside its C++ stream.
+        stream.current_stream.reset(
+            new CPyFlightDataStream(result, move(data_stream)))
+        # Recurse to emit the new substream's first message, or to move
+        # past it if it is empty. NOTE(review): each consecutive empty
+        # substream adds one level of recursion; confirm this is acceptable.
+        _data_stream_next(self, payload)
+    elif isinstance(result, RecordBatch):
+        batch = <RecordBatch> result
+        if batch.schema != stream_schema:
+            raise ValueError("Got a RecordBatch whose schema does not "
+                             "match the declared schema of this "
+                             "GeneratorStream. "
+                             "Got: {}\nExpected: {}".format(batch.schema,
+                                                            stream_schema))
+        # Serialize the batch directly into the payload's IPC message.
+        check_status(_GetRecordBatchPayload(
+            deref(batch.batch),
+            c_default_memory_pool(),
+            &payload.ipc_message))
+    else:
+        raise TypeError("GeneratorStream must be initialized with "
+                        "an iterator of FlightDataStream, Table, "
+                        "RecordBatch, or RecordBatchStreamReader objects, "
+                        "not {}.".format(type(result)))
+
+
 cdef void _list_flights(void* self, const CCriteria* c_criteria,
                         unique_ptr[CFlightListing]* listing) except *:
     """Callback for implementing ListFlights in Python."""
diff --git a/python/pyarrow/flight.py b/python/pyarrow/flight.py
index ae80a8b..f797419 100644
--- a/python/pyarrow/flight.py
+++ b/python/pyarrow/flight.py
@@ -23,6 +23,7 @@ from pyarrow._flight import (Action,  # noqa
                              FlightEndpoint,
                              FlightInfo,
                              FlightServerBase,
+                             GeneratorStream,
                              Location,
                              Ticket,
                              RecordBatchStream,
diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd
index 153f725..faf58cb 100644
--- a/python/pyarrow/includes/libarrow_flight.pxd
+++ b/python/pyarrow/includes/libarrow_flight.pxd
@@ -23,6 +23,20 @@ from pyarrow.includes.common cimport *
 from pyarrow.includes.libarrow cimport *
 
 
+# IPC helpers used by GeneratorStream to serialize a RecordBatch straight
+# into a Flight payload. NOTE(review): these live in arrow::ipc::internal,
+# so they are presumably not a stable public API.
+cdef extern from "arrow/ipc/api.h" namespace "arrow" nogil:
+    cdef cppclass CIpcPayload" arrow::ipc::internal::IpcPayload":
+        MessageType type
+        shared_ptr[CBuffer] metadata
+        vector[shared_ptr[CBuffer]] body_buffers
+        int64_t body_length
+
+    cdef CStatus _GetRecordBatchPayload\
+        " arrow::ipc::internal::GetRecordBatchPayload"(
+            const CRecordBatch& batch,
+            CMemoryPool* pool,
+            CIpcPayload* out)
+
+
+
 cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
     cdef cppclass CActionType" arrow::flight::ActionType":
         c_string type
@@ -94,8 +108,13 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
             " arrow::flight::FlightMessageReader"(CRecordBatchReader):
         CFlightDescriptor& descriptor()
 
+    # Message filled in by FlightDataStream.Next(): a descriptor buffer
+    # plus an IPC message.
+    cdef cppclass CFlightPayload" arrow::flight::FlightPayload":
+        shared_ptr[CBuffer] descriptor
+        CIpcPayload ipc_message
+
     cdef cppclass CFlightDataStream" arrow::flight::FlightDataStream":
-        pass
+        # Declared so GeneratorStream can check substream schemas and pull
+        # payloads from substreams directly.
+        shared_ptr[CSchema] schema()
+        CStatus Next(CFlightPayload*)
 
     cdef cppclass CRecordBatchStream \
             " arrow::flight::RecordBatchStream"(CFlightDataStream):
@@ -133,6 +152,7 @@ ctypedef void cb_do_action(object, const CAction&,
                            unique_ptr[CResultStream]*)
 ctypedef void cb_list_actions(object, vector[CActionType]*)
 ctypedef void cb_result_next(object, unique_ptr[CResult]*)
+# Fills in the next payload of a Python-implemented FlightDataStream.
+ctypedef void cb_data_stream_next(object, CFlightPayload*)
 
 cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil:
     cdef cppclass PyFlightServerVtable:
@@ -161,6 +181,13 @@ cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil:
         CPyFlightDataStream(object data_source,
                             unique_ptr[CFlightDataStream] stream)
 
+    # C++ FlightDataStream that calls `callback` on the wrapped Python
+    # object for each payload. Construct only while holding the GIL.
+    cdef cppclass CPyGeneratorFlightDataStream\
+            " arrow::py::flight::PyGeneratorFlightDataStream"\
+            (CFlightDataStream):
+        CPyGeneratorFlightDataStream(object generator,
+                                     shared_ptr[CSchema] schema,
+                                     function[cb_data_stream_next] callback)
+
     cdef CStatus CreateFlightInfo" arrow::py::flight::CreateFlightInfo"(
         shared_ptr[CSchema] schema,
         CFlightDescriptor& descriptor,
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index b1b6a12..73dd018 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -57,6 +57,28 @@ class EchoFlightServer(flight.FlightServerBase):
         self.last_message = reader.read_all()
 
 
+class EchoStreamFlightServer(EchoFlightServer):
+    """An echo server that streams individual record batches."""
+
+    def do_get(self, ticket):
+        # Replay self.last_message (stored by EchoFlightServer) as batches
+        # of at most 1024 rows; the ticket is ignored.
+        # NOTE(review): assumes a table was uploaded before do_get is called.
+        return flight.GeneratorStream(
+            self.last_message.schema,
+            self.last_message.to_batches(chunksize=1024))
+
+
+class InvalidStreamFlightServer(flight.FlightServerBase):
+    """A Flight server that tries to return messages with differing schemas."""
+    # Same column name, but integer vs. floating-point values, so table2's
+    # schema differs from table1's.
+    data1 = [pa.array([-10, -5, 0, 5, 10])]
+    data2 = [pa.array([-10.0, -5.0, 0.0, 5.0, 10.0])]
+    table1 = pa.Table.from_arrays(data1, names=['a'])
+    table2 = pa.Table.from_arrays(data2, names=['a'])
+
+    # Declared stream schema: matches table1 only.
+    schema = table1.schema
+
+    def do_get(self, ticket):
+        # table2 violates the declared schema, so streaming it should fail.
+        return flight.GeneratorStream(self.schema, [self.table1, self.table2])
+
+
 @contextlib.contextmanager
 def flight_server(server_base, *args, **kwargs):
     """Spawn a Flight server on a free port, shutting it down when done."""
@@ -114,3 +136,28 @@ def test_flight_large_message():
         writer.close()
         result = client.do_get(flight.Ticket(b''), data.schema).read_all()
         assert result.equals(data)
+
+
+def test_flight_generator_stream():
+    """Try downloading a flight of RecordBatches in a GeneratorStream."""
+    # 10 * 1024 rows, so the server's 1024-row chunking yields several
+    # batches.
+    data = pa.Table.from_arrays([
+        pa.array(range(0, 10 * 1024))
+    ], names=['a'])
+
+    with flight_server(EchoStreamFlightServer) as server_port:
+        client = flight.FlightClient.connect('localhost', server_port)
+        # Upload the table, then read it back through the GeneratorStream.
+        writer = client.do_put(flight.FlightDescriptor.for_path('test'),
+                               data.schema)
+        writer.write_table(data)
+        writer.close()
+        result = client.do_get(flight.Ticket(b''), data.schema).read_all()
+        assert result.equals(data)
+
+
+def test_flight_invalid_generator_stream():
+    """Try streaming data with mismatched schemas."""
+    with flight_server(InvalidStreamFlightServer) as server_port:
+        client = flight.FlightClient.connect('localhost', server_port)
+        schema = InvalidStreamFlightServer.schema
+        # The server-side schema mismatch on the second table is expected to
+        # surface on the client as an ArrowException while reading.
+        with pytest.raises(pa.ArrowException):
+            client.do_get(flight.Ticket(b''), schema).read_all()

Reply via email to