lidavidm commented on a change in pull request #10323:
URL: https://github.com/apache/arrow/pull/10323#discussion_r646837177



##########
File path: python/pyarrow/_flight.pyx
##########
@@ -1655,70 +1655,83 @@ cdef CStatus _data_stream_next(void* self, 
CFlightPayload* payload) except *:
         raise RuntimeError("self object in callback is not GeneratorStream")
     stream = <GeneratorStream> py_stream
 
-    if stream.current_stream != nullptr:
-        check_flight_status(stream.current_stream.get().Next(payload))
-        # If the stream ended, see if there's another stream from the
-        # generator
-        if payload.ipc_message.metadata != nullptr:
+    # The generator is allowed to yield a reader or table which we
+    # yield from; if that sub-generator is empty, we need to reset and
+    # try again. However, limit the number of attempts so that we
+    # don't just spin forever.
+    max_attempts = 128
+    for _ in range(max_attempts):
+        if stream.current_stream != nullptr:
+            check_flight_status(stream.current_stream.get().Next(payload))
+            # If the stream ended, see if there's another stream from the
+            # generator
+            if payload.ipc_message.metadata != nullptr:
+                return CStatus_OK()
+            stream.current_stream.reset(nullptr)
+
+        try:
+            result = next(stream.generator)
+        except StopIteration:
+            payload.ipc_message.metadata.reset(<CBuffer*> nullptr)
             return CStatus_OK()
-        stream.current_stream.reset(nullptr)
+        except FlightError as flight_error:
+            return (<FlightError> flight_error).to_status()
 
-    try:
-        result = next(stream.generator)
-    except StopIteration:
-        payload.ipc_message.metadata.reset(<CBuffer*> nullptr)
+        if isinstance(result, (list, tuple)):
+            result, metadata = result
+        else:
+            result, metadata = result, None
+
+        if isinstance(result, (Table, RecordBatchReader)):
+            if metadata:
+                raise ValueError("Can only return metadata alongside a "
+                                 "RecordBatch.")
+            result = RecordBatchStream(result)
+
+        stream_schema = pyarrow_wrap_schema(stream.schema)
+        if isinstance(result, FlightDataStream):
+            if metadata:
+                raise ValueError("Can only return metadata alongside a "
+                                 "RecordBatch.")
+            data_stream = unique_ptr[CFlightDataStream](
+                (<FlightDataStream> result).to_stream())
+            substream_schema = pyarrow_wrap_schema(data_stream.get().schema())
+            if substream_schema != stream_schema:
+                raise ValueError("Got a FlightDataStream whose schema "
+                                 "does not match the declared schema of this "
+                                 "GeneratorStream. "
+                                 "Got: {}\nExpected: {}".format(
+                                     substream_schema, stream_schema))
+            stream.current_stream.reset(
+                new CPyFlightDataStream(result, move(data_stream)))
+            # Loop around and try again
+            continue
+        elif isinstance(result, RecordBatch):
+            batch = <RecordBatch> result
+            if batch.schema != stream_schema:
+                raise ValueError("Got a RecordBatch whose schema does not "
+                                 "match the declared schema of this "
+                                 "GeneratorStream. "
+                                 "Got: {}\nExpected: {}".format(batch.schema,
+                                                                stream_schema))
+            check_flight_status(GetRecordBatchPayload(
+                deref(batch.batch),
+                stream.c_options,
+                &payload.ipc_message))
+            if metadata:
+                payload.app_metadata = pyarrow_unwrap_buffer(
+                    as_buffer(metadata))
+        else:
+            raise TypeError("GeneratorStream must be initialized with "
+                            "an iterator of FlightDataStream, Table, "
+                            "RecordBatch, or RecordBatchStreamReader objects, "
+                            "not {}.".format(type(result)))
+        # Don't loop around
         return CStatus_OK()
-    except FlightError as flight_error:
-        return (<FlightError> flight_error).to_status()
-
-    if isinstance(result, (list, tuple)):
-        result, metadata = result
-    else:
-        result, metadata = result, None
-
-    if isinstance(result, (Table, RecordBatchReader)):
-        if metadata:
-            raise ValueError("Can only return metadata alongside a "
-                             "RecordBatch.")
-        result = RecordBatchStream(result)
-
-    stream_schema = pyarrow_wrap_schema(stream.schema)
-    if isinstance(result, FlightDataStream):
-        if metadata:
-            raise ValueError("Can only return metadata alongside a "
-                             "RecordBatch.")
-        data_stream = unique_ptr[CFlightDataStream](
-            (<FlightDataStream> result).to_stream())
-        substream_schema = pyarrow_wrap_schema(data_stream.get().schema())
-        if substream_schema != stream_schema:
-            raise ValueError("Got a FlightDataStream whose schema does not "
-                             "match the declared schema of this "
-                             "GeneratorStream. "
-                             "Got: {}\nExpected: {}".format(substream_schema,
-                                                            stream_schema))
-        stream.current_stream.reset(
-            new CPyFlightDataStream(result, move(data_stream)))
-        return _data_stream_next(self, payload)
-    elif isinstance(result, RecordBatch):
-        batch = <RecordBatch> result
-        if batch.schema != stream_schema:
-            raise ValueError("Got a RecordBatch whose schema does not "
-                             "match the declared schema of this "
-                             "GeneratorStream. "
-                             "Got: {}\nExpected: {}".format(batch.schema,
-                                                            stream_schema))
-        check_flight_status(GetRecordBatchPayload(
-            deref(batch.batch),
-            stream.c_options,
-            &payload.ipc_message))
-        if metadata:
-            payload.app_metadata = pyarrow_unwrap_buffer(as_buffer(metadata))
-    else:
-        raise TypeError("GeneratorStream must be initialized with "
-                        "an iterator of FlightDataStream, Table, "
-                        "RecordBatch, or RecordBatchStreamReader objects, "
-                        "not {}.".format(type(result)))
-    return CStatus_OK()
+    # Ran out of attempts (the RPC handler kept yielding empty tables/readers)
+    raise RuntimeError("While getting next payload, ran out of attempts to "
+                       "get something to send "
+                       "(application server implementation error)")

Review comment:
       Though, to be clear just in case: this function adapts a generator 
of (batches, tables, readers, etc.) into a generator of just batches, so I 
mean the inner table/reader/etc. here.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to