lidavidm commented on a change in pull request #10318:
URL: https://github.com/apache/arrow/pull/10318#discussion_r644153441



##########
File path: cpp/src/arrow/flight/client.h
##########
@@ -129,6 +133,12 @@ class ARROW_FLIGHT_EXPORT FlightStreamReader : public MetadataRecordBatchReader
  public:
   /// \brief Try to cancel the call.
   virtual void Cancel() = 0;
+  using MetadataRecordBatchReader::ReadAll;
+  /// \brief Consume entire stream as a vector of record batches
+  virtual Status ReadAll(std::vector<std::shared_ptr<RecordBatch>>* batches,
+                         const StopToken& stop_token) = 0;
+  /// \brief Consume entire stream as a Table
+  virtual Status ReadAll(std::shared_ptr<Table>* table, const StopToken& stop_token) = 0;

Review comment:
       I made this non-virtual.
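The `using MetadataRecordBatchReader::ReadAll;` line matters because, in C++, declaring any `ReadAll` in a derived class hides every base-class overload with that name. The sketch below shows that pattern with the new overloads made non-virtual and built on one virtual primitive. It is only an illustration under assumptions: `Status`, `Batch`, `Table`, `StopToken`, `BaseReader`, `StreamReader`, `CountingReader` and `ReadNext` are simplified, hypothetical stand-ins, not Arrow's real classes or how the PR actually implements these methods.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for Arrow's Status, RecordBatch, Table and StopToken.
struct Status {
  bool ok = true;
  std::string message;
  static Status OK() { return {}; }
  static Status Cancelled(std::string msg) { return {false, std::move(msg)}; }
};
struct Batch { int num_rows = 0; };
struct Table { std::vector<std::shared_ptr<Batch>> batches; };
struct StopToken { bool stop_requested = false; };

// Plays the role of MetadataRecordBatchReader.
class BaseReader {
 public:
  virtual ~BaseReader() = default;
  // Virtual primitive: sets *out to the next batch, or to nullptr at end of stream.
  virtual Status ReadNext(std::shared_ptr<Batch>* out) = 0;

  // Non-virtual convenience overload without a stop token.
  Status ReadAll(std::vector<std::shared_ptr<Batch>>* batches) {
    std::shared_ptr<Batch> batch;
    while (true) {
      Status st = ReadNext(&batch);
      if (!st.ok) return st;
      if (!batch) return Status::OK();
      batches->push_back(std::move(batch));
    }
  }
};

// Plays the role of FlightStreamReader.
class StreamReader : public BaseReader {
 public:
  // Without this, the ReadAll overloads below would hide BaseReader::ReadAll.
  using BaseReader::ReadAll;

  // Non-virtual: subclasses only implement ReadNext(); the token is checked
  // before each batch is read.
  Status ReadAll(std::vector<std::shared_ptr<Batch>>* batches, const StopToken& token) {
    std::shared_ptr<Batch> batch;
    while (true) {
      if (token.stop_requested) return Status::Cancelled("stop requested");
      Status st = ReadNext(&batch);
      if (!st.ok) return st;
      if (!batch) return Status::OK();
      batches->push_back(std::move(batch));
    }
  }

  // The Table overload just delegates to the vector overload.
  Status ReadAll(std::shared_ptr<Table>* table, const StopToken& token) {
    assert(table != nullptr);
    auto result = std::make_shared<Table>();
    Status st = ReadAll(&result->batches, token);
    if (!st.ok) return st;
    *table = std::move(result);
    return Status::OK();
  }
};

// Hypothetical concrete reader that yields `n` one-row batches.
class CountingReader : public StreamReader {
 public:
  explicit CountingReader(int n) : remaining_(n) {}
  Status ReadNext(std::shared_ptr<Batch>* out) override {
    if (remaining_ == 0) {
      out->reset();
      return Status::OK();
    }
    --remaining_;
    *out = std::make_shared<Batch>(Batch{1});
    return Status::OK();
  }

 private:
  int remaining_;
};
```

Deleting the `using` line makes `CountingReader(3).ReadAll(&batches)` fail to compile, because name lookup stops at `StreamReader`, the first scope that declares a `ReadAll`. Keeping the overloads non-virtual also means implementations override only the primitive, so the stop-token check cannot be skipped by an override.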

##########
File path: python/pyarrow/tests/test_flight.py
##########
@@ -1810,3 +1813,62 @@ def test_generic_options():
                                 generic_options=options)
         with pytest.raises(pa.ArrowInvalid):
             client.do_get(flight.Ticket(b'ints'))
+
+
+class CancelFlightServer(FlightServerBase):
+    """A server for testing StopToken."""
+
+    def do_get(self, context, ticket):
+        schema = pa.schema([])
+        rb = pa.RecordBatch.from_arrays([], schema=schema)
+        return flight.GeneratorStream(schema, itertools.repeat(rb))
+
+    def do_exchange(self, context, descriptor, reader, writer):
+        schema = pa.schema([])
+        rb = pa.RecordBatch.from_arrays([], schema=schema)
+        writer.begin(schema)
+        while not context.is_cancelled():
+            # TODO: writing schema.empty_table() here hangs/fails
+            writer.write_batch(rb)
+            time.sleep(0.5)
+
+
+def test_interrupt():
+    if threading.current_thread().ident != threading.main_thread().ident:
+        pytest.skip("test only works from main Python thread")
+    # Skips test if not available
+    raise_signal = util.get_raise_signal()
+
+    def signal_from_thread():
+        time.sleep(0.5)
+        raise_signal(signal.SIGINT)
+
+    exc_types = (KeyboardInterrupt, pa.ArrowCancelled)
+
+    def test(read_all):
+        try:
+            try:
+                t = threading.Thread(target=signal_from_thread)
+                with pytest.raises(exc_types) as exc_info:
+                    t.start()
+                    read_all()
+            finally:
+                t.join()
+        except KeyboardInterrupt:
+            # In case KeyboardInterrupt didn't interrupt read_all
+            # above, at least prevent it from stopping the test suite
+            # pytest.fail("KeyboardInterrupt didn't interrupt Flight read_all")
+            raise
+        e = exc_info.value.__context__
+        assert isinstance(e, pa.ArrowCancelled) or isinstance(
+            e, pa.ArrowCancelled)

Review comment:
       I meant for the second `isinstance` check to be for `KeyboardInterrupt`.
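`test_interrupt` exercises cooperative cancellation. A helper thread raises SIGINT while the main thread is blocked inside `read_all` on an endless stream (`CancelFlightServer.do_get` repeats one batch forever). For the read to stop, the C++ loop has to notice the stop request itself and return a Cancelled status, which surfaces in Python as `pa.ArrowCancelled`. That is presumably why the intended assertion accepts either exception, e.g. `isinstance(e, (pa.ArrowCancelled, KeyboardInterrupt))`. The sketch below shows only the polling idea. `Status`, `StopSource`, `StopToken`, `ReadAllEndless` and `RunInterruptDemo` are simplified, hypothetical stand-ins loosely modelled on Arrow's `StopSource`/`StopToken`, and it uses a plain thread rather than a real signal handler.

```cpp
#include <atomic>
#include <cassert>
#include <chrono>
#include <memory>
#include <string>
#include <thread>
#include <utility>
#include <vector>

// Hypothetical stand-in for arrow::Status.
struct Status {
  bool ok = true;
  std::string message;
  static Status OK() { return {}; }
  static Status Cancelled(std::string msg) { return {false, std::move(msg)}; }
};

class StopToken {
 public:
  explicit StopToken(std::shared_ptr<std::atomic<bool>> flag) : flag_(std::move(flag)) {}
  // Returns a Cancelled status once a stop has been requested on the source.
  Status Poll() const {
    return flag_->load() ? Status::Cancelled("Operation cancelled") : Status::OK();
  }

 private:
  std::shared_ptr<std::atomic<bool>> flag_;
};

class StopSource {
 public:
  // Every token shares the source's flag, so RequestStop() is visible to all of them.
  StopToken token() const { return StopToken(flag_); }
  void RequestStop() { flag_->store(true); }

 private:
  std::shared_ptr<std::atomic<bool>> flag_ = std::make_shared<std::atomic<bool>>(false);
};

// Drains an endless stream, polling the token before each batch.
Status ReadAllEndless(std::vector<int>* batches, const StopToken& token) {
  assert(batches != nullptr);
  while (true) {
    Status st = token.Poll();
    if (!st.ok) return st;
    batches->push_back(0);  // stand-in for one received (empty) batch
    std::this_thread::sleep_for(std::chrono::milliseconds(1));
  }
}

// Mirrors test_interrupt: a helper thread requests a stop after a delay while
// the calling thread is blocked in the read loop.
Status RunInterruptDemo(std::vector<int>* batches) {
  StopSource source;
  std::thread stopper([&source] {
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
    source.RequestStop();
  });
  Status st = ReadAllEndless(batches, source.token());
  stopper.join();
  return st;
}
```

Because the token is only polled between batches, a stop request takes effect once the current batch has arrived. Aborting the underlying call itself is what `FlightStreamReader::Cancel()` is for.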




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
