This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new c7f5306c feat(python/adbc_driver_manager): handle KeyboardInterrupt (#1509)
c7f5306c is described below

commit c7f5306c7341509222573693cfdddbceb1c6f676
Author: David Li <li.david...@gmail.com>
AuthorDate: Tue Feb 6 14:40:46 2024 -0500

    feat(python/adbc_driver_manager): handle KeyboardInterrupt (#1509)
    
    Alternative to #1486.
    
    Fixes #1484.
---
 .github/workflows/native-unix.yml                  |  13 +
 ci/scripts/python_test.sh                          |   4 +-
 docker-compose.yml                                 |   6 +
 python/adbc_driver_flightsql/tests/test_errors.py  |  20 ++
 python/adbc_driver_manager/MANIFEST.in             |   2 +
 .../adbc_driver_manager/_blocking_impl.cc          | 269 +++++++++++++++++++++
 .../adbc_driver_manager/_blocking_impl.h           |  38 +++
 .../adbc_driver_manager/_lib.pyi                   |  11 +
 .../adbc_driver_manager/_lib.pyx                   |  99 ++++++++
 .../adbc_driver_manager/dbapi.py                   |  17 +-
 python/adbc_driver_manager/pyproject.toml          |   1 +
 python/adbc_driver_manager/setup.py                |   3 +
 python/adbc_driver_manager/tests/test_blocking.py  | 141 +++++++++++
 13 files changed, 617 insertions(+), 7 deletions(-)

diff --git a/.github/workflows/native-unix.yml b/.github/workflows/native-unix.yml
index 142d5f4b..c4ab9f44 100644
--- a/.github/workflows/native-unix.yml
+++ b/.github/workflows/native-unix.yml
@@ -477,7 +477,29 @@ jobs:
       - name: Test Python Driver Flight SQL
         shell: bash -l {0}
         run: |
+          # Can't use Docker on macOS
+          pushd $(pwd)/go/adbc
+          go build -o testserver ./driver/flightsql/cmd/testserver
+          popd
+          $(pwd)/go/adbc/testserver -host 0.0.0.0 -port 41414 &
+          testserver_pid=$!
+          # Stop the server via a trap so the step's exit status stays that of
+          # the tests (a trailing "kill" would mask test failures)
+          trap 'kill "${testserver_pid}"' EXIT
+          attempts=0
+          while ! curl --http2-prior-knowledge -H "content-type: application/grpc" -v localhost:41414 -XPOST;
+          do
+              attempts=$((attempts + 1))
+              if [[ "${attempts}" -ge 60 ]]; then
+                  echo "Test server did not start"
+                  exit 1
+              fi
+              echo "Waiting for test server..."
+              jobs
+              sleep 5
+          done
+          export ADBC_TEST_FLIGHTSQL_URI=grpc://localhost:41414
           env BUILD_ALL=0 BUILD_DRIVER_FLIGHTSQL=1 ./ci/scripts/python_test.sh "$(pwd)" "$(pwd)/build" "$HOME/local"
       - name: Build Python Driver PostgreSQL
         shell: bash -l {0}
         run: |
diff --git a/ci/scripts/python_test.sh b/ci/scripts/python_test.sh
index f8d70917..6f95b589 100755
--- a/ci/scripts/python_test.sh
+++ b/ci/scripts/python_test.sh
@@ -58,8 +58,8 @@ test_subproject() {
     fi
 
     echo "=== Testing ${subproject} ==="
-    echo env ${options[@]} python -m pytest -vv "${source_dir}/python/${subproject}/tests"
-    env ${options[@]} python -m pytest -vv "${source_dir}/python/${subproject}/tests"
+    echo env ${options[@]} python -m pytest -vvs --full-trace "${source_dir}/python/${subproject}/tests"
+    env ${options[@]} python -m pytest -vvs --full-trace "${source_dir}/python/${subproject}/tests"
     echo
 }
 
diff --git a/docker-compose.yml b/docker-compose.yml
index 89394e59..789d5d45 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -150,6 +150,12 @@ services:
       dockerfile: ci/docker/flightsql-test.dockerfile
       args:
         GO: ${GO}
+    healthcheck:
+      test: ["CMD", "curl", "--http2-prior-knowledge", "-XPOST", "-H", "content-type: application/grpc", "localhost:41414"]
+      interval: 5s
+      timeout: 30s
+      retries: 3
+      start_period: 5m
     ports:
       - "41414:41414"
     volumes:
diff --git a/python/adbc_driver_flightsql/tests/test_errors.py b/python/adbc_driver_flightsql/tests/test_errors.py
index ed44b6a3..ee2b62d3 100644
--- a/python/adbc_driver_flightsql/tests/test_errors.py
+++ b/python/adbc_driver_flightsql/tests/test_errors.py
@@ -16,6 +16,8 @@
 # under the License.
 
 import re
+import threading
+import time
 
 import google.protobuf.any_pb2 as any_pb2
 import google.protobuf.wrappers_pb2 as wrappers_pb2
@@ -45,6 +47,24 @@ def test_query_cancel(test_dbapi):
             cur.fetchone()
 
 
+def test_query_cancel_async(test_dbapi):
+    """A cancel issued from another thread interrupts a blocked fetchone()."""
+    with test_dbapi.cursor() as cur:
+        cur.execute("forever")
+
+        # Cancel from a second thread while this one is blocked in fetchone();
+        # it is a daemon thread so it can never hold up interpreter exit.
+        def _sleep_then_cancel():
+            time.sleep(2)
+            cur.adbc_cancel()
+
+        canceller = threading.Thread(target=_sleep_then_cancel, daemon=True)
+        canceller.start()
+
+        expected = re.escape("CANCELLED: [FlightSQL] context canceled")
+        with pytest.raises(test_dbapi.OperationalError, match=expected):
+            cur.fetchone()
+
 def test_query_error_fetch(test_dbapi):
     with test_dbapi.cursor() as cur:
         cur.execute("error_do_get")
diff --git a/python/adbc_driver_manager/MANIFEST.in b/python/adbc_driver_manager/MANIFEST.in
index 306c3114..298ff3a9 100644
--- a/python/adbc_driver_manager/MANIFEST.in
+++ b/python/adbc_driver_manager/MANIFEST.in
@@ -22,6 +22,8 @@ include NOTICE.txt
 include adbc_driver_manager/adbc.h
 include adbc_driver_manager/adbc_driver_manager.cc
 include adbc_driver_manager/adbc_driver_manager.h
+include adbc_driver_manager/_blocking_impl.cc
+include adbc_driver_manager/_blocking_impl.h
 include adbc_driver_manager/_lib.pxd
 include adbc_driver_manager/_lib.pyi
 include adbc_driver_manager/_reader.pyi
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.cc b/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.cc
new file mode 100644
index 00000000..766b3964
--- /dev/null
+++ b/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.cc
@@ -0,0 +1,317 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "_blocking_impl.h"
+
+#if defined(_WIN32)
+#define NOMINMAX
+#define WIN32_LEAN_AND_MEAN
+#include <errno.h>
+#include <fcntl.h>
+#include <io.h>
+#include <windows.h>
+#else
+#include <fcntl.h>
+#include <pthread.h>
+#include <unistd.h>
+#endif
+
+#include <cerrno>
+#include <csignal>
+#include <cstring>
+#include <iostream>
+#include <mutex>
+#include <string>
+#include <system_error>
+#include <thread>
+
+namespace pyadbc_driver_manager {
+
+// This is somewhat derived from io_util.cc in arrow, but that implementation
+// isn't easily used outside of Arrow's monolith.
+namespace {
+static std::once_flag kInitOnce;
+// We may encounter errors below that we can't do anything about. Use this to
+// print out an error, once.
+static std::once_flag kWarnOnce;
+// This thread reads from a pipe until the pipe fails or is closed.  Whenever
+// it reads something, it calls the callback below.
+static std::thread kCancelThread;
+// Set (inside kInitOnce) once the pipe and the thread are both ready.  If
+// initialization failed, installing our SIGINT handler would swallow Ctrl-C,
+// so SetBlockingCallback refuses to do it.
+static bool initialized = false;
+// The error message from initialization, if any.
+static std::string init_error;
+
+static std::mutex cancel_mutex;
+// This callback is registered by the Python side; basically it will call
+// cancel() on an ADBC object.
+static void (*cancel_callback)(void*) = nullptr;
+// Callback state (a pointer to the ADBC PyObject).
+static void* cancel_callback_data = nullptr;
+// A self-pipe: the write end is nonblocking (it is written from the signal
+// handler), the read end blocks (it is read by kCancelThread).
+static int pipe[2];
+// Whether our SIGINT handler is currently installed (guarded by
+// cancel_mutex).  Tracked explicitly because the saved handler may be
+// SIG_DFL, which compares equal to nullptr and so can't be used as a flag.
+static bool handler_installed = false;
+#if defined(_WIN32)
+void (*old_sigint)(int);
+#else
+// The old signal handler (most likely Python's).
+struct sigaction old_sigint;
+// Our signal handler (below).
+struct sigaction our_sigint;
+#endif
+
+std::string MakePipe() {
+  int rc = 0;
+#if defined(__linux__) && defined(__GLIBC__)
+  rc = pipe2(pipe, O_CLOEXEC);
+#elif defined(_WIN32)
+  rc = _pipe(pipe, 4096, _O_BINARY);
+#else
+  rc = ::pipe(pipe);
+#endif
+
+  if (rc != 0) {
+    return std::strerror(errno);
+  }
+
+#if (!defined(__linux__) || !defined(__GLIBC__)) && !defined(_WIN32)
+  {
+    int flags = fcntl(pipe[0], F_GETFD, 0);
+    if (flags < 0) {
+      return std::strerror(errno);
+    }
+    rc = fcntl(pipe[0], F_SETFD, flags | FD_CLOEXEC);
+    if (rc < 0) {
+      return std::strerror(errno);
+    }
+
+    flags = fcntl(pipe[1], F_GETFD, 0);
+    if (flags < 0) {
+      return std::strerror(errno);
+    }
+    rc = fcntl(pipe[1], F_SETFD, flags | FD_CLOEXEC);
+    if (rc < 0) {
+      return std::strerror(errno);
+    }
+  }
+#endif
+
+  // Make the write side nonblocking (the read side should stay blocking!)
+#if defined(_WIN32)
+  const auto handle = reinterpret_cast<HANDLE>(_get_osfhandle(pipe[1]));
+  DWORD mode = PIPE_NOWAIT;
+  if (!SetNamedPipeHandleState(handle, &mode, nullptr, nullptr)) {
+    DWORD last_error = GetLastError();
+    LPSTR message = nullptr;
+
+    // Use the ANSI variant explicitly: the message is read back as a plain
+    // char string whether or not UNICODE is defined.
+    FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM |
+                       FORMAT_MESSAGE_IGNORE_INSERTS,
+                   /*lpSource=*/nullptr, last_error,
+                   MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
+                   reinterpret_cast<LPSTR>(&message), /*nSize=*/0,
+                   /*Arguments=*/nullptr);
+
+    std::string buffer = "(";
+    buffer += std::to_string(last_error);
+    buffer += ")";
+    // FormatMessageA leaves the buffer unset if it fails.
+    if (message != nullptr) {
+      buffer += " ";
+      buffer += message;
+      LocalFree(message);
+    }
+    return buffer;
+  }
+#else
+  {
+    int flags = fcntl(pipe[1], F_GETFL, 0);
+    if (flags < 0) {
+      return std::strerror(errno);
+    }
+    rc = fcntl(pipe[1], F_SETFL, flags | O_NONBLOCK);
+    if (rc < 0) {
+      return std::strerror(errno);
+    }
+  }
+#endif
+
+  return "";
+}
+
+void InterruptThread() {
+#if defined(__APPLE__)
+  pthread_setname_np("AdbcInterrupt");
+#endif
+
+  while (true) {
+    char buf = 0;
+    // Anytime something is written to the pipe, attempt to call the callback
+#if defined(_WIN32)
+    auto bytes_read = _read(pipe[0], &buf, 1);
+#else
+    auto bytes_read = read(pipe[0], &buf, 1);
+#endif
+    if (bytes_read < 0) {
+      if (errno == EINTR) continue;
+
+      // We failed reading from the pipe.  Retrying would most likely fail the
+      // same way and spin this thread at 100% CPU, so report it (once) and
+      // stop handling interrupts.
+      std::string message = std::strerror(errno);
+      std::call_once(kWarnOnce, [&]() {
+        std::cerr << "adbc_driver_manager (native code): error handling interrupt: "
+                  << message << std::endl;
+      });
+      return;
+    }
+    if (bytes_read == 0) {
+      // EOF: the write end was closed, so no more interrupts can arrive.
+      return;
+    }
+
+    // Save the callback locally instead of calling it under the lock, since
+    // otherwise we may deadlock with the Python side trying to call us
+    void (*local_callback)(void*) = nullptr;
+    void* local_callback_data = nullptr;
+
+    {
+      std::lock_guard<std::mutex> lock(cancel_mutex);
+      local_callback = cancel_callback;
+      local_callback_data = cancel_callback_data;
+      cancel_callback = nullptr;
+      cancel_callback_data = nullptr;
+    }
+
+    if (local_callback != nullptr) {
+      local_callback(local_callback_data);
+    }
+  }
+}
+
+// We can't do much about failures here, so ignore the result.  If the pipe is
+// full, that's fine; it just means the thread has fallen behind in processing
+// earlier interrupts.
+void SigintHandler(int) {
+  // write() may clobber errno, and the code this handler interrupted may be
+  // about to inspect errno, so preserve it.
+  const int saved_errno = errno;
+#if defined(_WIN32)
+  (void)_write(pipe[1], "X", 1);
+#else
+  (void)write(pipe[1], "X", 1);
+#endif
+  errno = saved_errno;
+}
+
+}  // namespace
+
+std::string InitBlockingCallback() {
+  std::call_once(kInitOnce, []() {
+    init_error = MakePipe();
+    if (!init_error.empty()) {
+      return;
+    }
+
+#if !defined(_WIN32)
+    our_sigint.sa_handler = &SigintHandler;
+    our_sigint.sa_flags = 0;
+    sigemptyset(&our_sigint.sa_mask);
+#endif
+
+    try {
+      kCancelThread = std::thread(InterruptThread);
+    } catch (const std::system_error& e) {
+      // Don't let the exception escape into the Python binding.
+      init_error = e.what();
+      return;
+    }
+#if defined(__linux__)
+    pthread_setname_np(kCancelThread.native_handle(), "AdbcInterrupt");
+#endif
+    kCancelThread.detach();
+    initialized = true;
+  });
+  // Report a failure on every call, not just the first.
+  return init_error;
+}
+
+std::string SetBlockingCallback(void (*callback)(void*), void* data) {
+  std::lock_guard<std::mutex> lock(cancel_mutex);
+  if (!initialized) {
+    // Without the pipe and the thread, our handler would swallow SIGINT.
+    if (!init_error.empty()) return init_error;
+    return "KeyboardInterrupt support is not initialized";
+  }
+
+  cancel_callback = callback;
+  cancel_callback_data = data;
+
+  // Don't set the handler again if we're somehow called twice
+  if (handler_installed) {
+    return "";
+  }
+
+#if defined(_WIN32)
+  old_sigint = signal(SIGINT, &SigintHandler);
+  if (old_sigint == SIG_ERR) {
+    old_sigint = nullptr;
+    return std::strerror(errno);
+  }
+#else
+  if (sigaction(SIGINT, &our_sigint, &old_sigint) != 0) {
+    return std::strerror(errno);
+  }
+#endif
+  handler_installed = true;
+  return "";
+}
+
+std::string ClearBlockingCallback() {
+  std::lock_guard<std::mutex> lock(cancel_mutex);
+  cancel_callback = nullptr;
+  cancel_callback_data = nullptr;
+
+  if (!handler_installed) {
+    return "";
+  }
+
+  // On failure, leave handler_installed set so that a later call can retry
+  // restoring the saved handler.
+#if defined(_WIN32)
+  if (signal(SIGINT, old_sigint) == SIG_ERR) {
+    return std::strerror(errno);
+  }
+  old_sigint = nullptr;
+#else
+  if (sigaction(SIGINT, &old_sigint, nullptr) != 0) {
+    return std::strerror(errno);
+  }
+  std::memset(&old_sigint, 0, sizeof(old_sigint));
+#endif
+  handler_installed = false;
+  return "";
+}
+
+}  // namespace pyadbc_driver_manager
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.h b/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.h
new file mode 100644
index 00000000..ac76252f
--- /dev/null
+++ b/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.h
@@ -0,0 +1,43 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Allow KeyboardInterrupt to function with ADBC in Python.
+///
+/// Call SetBlockingCallback to register a callback.  This will temporarily
+/// suppress the Python SIGINT handler.  When SIGINT is received, this module
+/// will handle it by calling the callback.
+
+#pragma once
+
+#include <string>
+
+namespace pyadbc_driver_manager {
+
+/// \brief Set up the internal state (a self-pipe and a background thread)
+///   used to handle SIGINT.  Call this before SetBlockingCallback.
+/// \return An error message (or empty string).
+std::string InitBlockingCallback();
+/// \brief Set the callback for when SIGINT is received, installing the
+///   native SIGINT handler if it is not already installed.
+/// \return An error message (or empty string).
+std::string SetBlockingCallback(void (*callback)(void*), void* data);
+/// \brief Clear the callback for when SIGINT is received and restore the
+///   previously installed SIGINT handler.
+/// \return An error message (or empty string).
+std::string ClearBlockingCallback();
+
+}  // namespace pyadbc_driver_manager
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
index 7afada9e..2a818839 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
@@ -17,6 +17,7 @@
 
 # NOTE: generated with mypy's stubgen, then hand-edited to fix things
 
+import typing_extensions
 from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Union
 
 from typing import overload
@@ -201,3 +202,13 @@ def _test_error(
     vendor_code: Optional[int],
     sqlstate: Optional[str],
 ) -> Error: ...
+
+# NOTE(review): the import lines visible in this stub only bring in names via
+# `from typing import ...`; confirm a plain `import typing` exists elsewhere
+# in the stub, otherwise typing.TypeVar / typing.Callable won't resolve.
+_P = typing_extensions.ParamSpec("_P")
+_T = typing.TypeVar("_T")
+
+# Calls ``func(*args, **kwargs)`` and returns its result; while it runs,
+# SIGINT may instead invoke ``cancel()`` (see _lib.pyx).
+# NOTE(review): _P is not tied to ``args``/``kwargs`` here, so their types
+# are not checked against ``func``'s signature.
+def _blocking_call(
+    func: typing.Callable[_P, _T],
+    args: tuple,
+    kwargs: dict,
+    cancel: typing.Callable[[], None],
+) -> _T: ...
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index 91139100..309cd764 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -20,8 +20,12 @@
 """Low-level ADBC API."""
 
 import enum
+import functools
 import threading
+import os
 import typing
+import sys
+import warnings
 from typing import List, Tuple
 
 cimport cpython
@@ -33,6 +37,7 @@ from cpython.pycapsule cimport (
 from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t
 from libc.stdlib cimport malloc, free
 from libc.string cimport memcpy, memset
+from libcpp.string cimport string as c_string
 from libcpp.vector cimport vector as c_vector
 
 if typing.TYPE_CHECKING:
@@ -1481,3 +1486,97 @@ cdef class AdbcStatement(_AdbcHandle):
 cdef const CAdbcError* PyAdbcErrorFromArrayStream(
     CArrowArrayStream* stream, CAdbcStatusCode* status):
     return AdbcErrorFromArrayStream(stream, status)
+
+
+cdef extern from "_blocking_impl.h" nogil:
+    ctypedef void (*BlockingCallback)(void*) noexcept nogil
+    c_string CInitBlockingCallback"pyadbc_driver_manager::InitBlockingCallback"()
+    c_string CSetBlockingCallback"pyadbc_driver_manager::SetBlockingCallback"(BlockingCallback, void* data)
+    c_string CClearBlockingCallback"pyadbc_driver_manager::ClearBlockingCallback"()
+
+
+@functools.cache
+def _init_blocking_call():
+    """Start the native SIGINT machinery once; warn (don't raise) on failure."""
+    message = bytes(CInitBlockingCallback()).decode("utf-8")
+    if not message:
+        return
+    text = f"Failed to initialize KeyboardInterrupt support: {message}"
+    warnings.warn(text, RuntimeWarning)
+
+
+# Guards _blocking_exc, which is written by _handle_blocking_call (run from
+# the native interrupt thread) and read/cleared by _blocking_call_impl.
+_blocking_lock = threading.Lock()
+# sys.exc_info() tuple for the last exception raised by a cancel callback,
+# or None if there is none pending.
+_blocking_exc = None
+
+
+def _blocking_call_impl(func, args, kwargs, cancel):
+    """
+    Run functions that are expected to block with a native SIGINT handler.
+
+    While ``func`` runs, SIGINT is routed to a native handler, which calls
+    ``cancel`` (from a background thread) in place of the Python SIGINT
+    handler.  When not called from the main thread, ``func`` is simply called
+    directly.
+
+    Parameters
+    ----------
+    func : callable
+        The blocking function to call.
+    args : tuple
+        Positional arguments for ``func``.
+    kwargs : dict
+        Keyword arguments for ``func``.
+    cancel : callable
+        Called with no arguments if SIGINT is received while ``func`` runs.
+
+    Returns
+    -------
+    object
+        Whatever ``func`` returns.
+
+    Raises
+    ------
+    BaseException
+        Whatever ``func`` raises; if ``cancel`` also raised, that exception is
+        attached as the ``__cause__``.  If only ``cancel`` raised, its
+        exception is raised (chained from KeyboardInterrupt) even when
+        ``func`` returned normally.
+    """
+    global _blocking_exc
+
+    # Only the main thread swaps the SIGINT handler; elsewhere just call func.
+    if threading.current_thread() is not threading.main_thread():
+        return func(*args, **kwargs)
+
+    _init_blocking_call()
+
+    # Discard any cancel-callback exception left over from an earlier call.
+    with _blocking_lock:
+        if _blocking_exc:
+            _blocking_exc = None
+
+    # Set the callback for the background thread and save the signal handler
+    # TODO: ideally this would be no-op if already set
+    # NOTE(review): the native thread copies the callback/data under its lock
+    # and then calls it outside the lock, and nothing here keeps ``cancel``
+    # alive past this call; a SIGINT that lands just as func returns may then
+    # dereference a freed object (e.g. a temporary bound method) -- confirm.
+    error = bytes(
+        CSetBlockingCallback(&_handle_blocking_call, <void*>cancel)
+    ).decode("utf-8")
+    if error:
+        warnings.warn(
+            f"Failed to set SIGINT handler: {error}",
+            RuntimeWarning,
+        )
+
+    try:
+        return func(*args, **kwargs)
+    except BaseException as e:
+        # If cancel() also failed, chain its exception as the cause.
+        with _blocking_lock:
+            if _blocking_exc:
+                exc = _blocking_exc
+                _blocking_exc = None
+                raise e from exc[1].with_traceback(exc[2])
+        raise e
+    finally:
+        # Restore the signal handler
+        error = bytes(CClearBlockingCallback()).decode("utf-8")
+        if error:
+            warnings.warn(
+                f"Failed to restore SIGINT handler: {error}",
+                RuntimeWarning,
+            )
+        # Only cancel() failed (the except branch above clears it otherwise):
+        # surface its exception, overriding func's return value.
+        with _blocking_lock:
+            if _blocking_exc:
+                exc = _blocking_exc
+                _blocking_exc = None
+                raise exc[1].with_traceback(exc[2]) from KeyboardInterrupt
+
+
+if os.name == "nt":
+    # https://github.com/apache/arrow-adbc/issues/1522
+    # KeyboardInterrupt support is disabled on Windows: just call func.
+    def _blocking_call(func, args, kwargs, cancel):
+        return func(*args, **kwargs)
+else:
+    _blocking_call = _blocking_call_impl
+
+
+
+cdef void _handle_blocking_call(void* c_cancel) noexcept nogil:
+    # Invoked by the native interrupt thread.  This function cannot propagate
+    # Python exceptions, so any error from cancel() is stashed in
+    # _blocking_exc for _blocking_call_impl to re-raise.
+    with gil:
+        global _blocking_exc
+        try:
+            (<object> c_cancel)()
+        except BaseException:
+            with _blocking_lock:
+                _blocking_exc = sys.exc_info()
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 1e86144c..ee6f318d 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -55,6 +55,7 @@ else:
 import adbc_driver_manager
 
 from . import _lib, _reader
+from ._lib import _blocking_call
 
 if typing.TYPE_CHECKING:
     import pandas
@@ -677,9 +678,12 @@ class Cursor(_Closeable):
             parameters, which will each be bound in turn).
         """
         self._prepare_execute(operation, parameters)
-        handle, self._rowcount = self._stmt.execute_query()
+
+        handle, self._rowcount = _blocking_call(
+            self._stmt.execute_query, (), {}, self._stmt.cancel
+        )
         self._results = _RowIterator(
-            _reader.AdbcRecordBatchReader._import_from_c(handle.address)
+            self._stmt, 
_reader.AdbcRecordBatchReader._import_from_c(handle.address)
         )
 
     def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> 
None:
@@ -991,7 +995,7 @@ class Cursor(_Closeable):
         handle = self._conn._conn.read_partition(partition)
         self._rowcount = -1
         self._results = _RowIterator(
-            pyarrow.RecordBatchReader._import_from_c(handle.address)
+            self._stmt, 
pyarrow.RecordBatchReader._import_from_c(handle.address)
         )
 
     @property
@@ -1095,7 +1099,8 @@ class Cursor(_Closeable):
 class _RowIterator(_Closeable):
     """Track state needed to iterate over the result set."""
 
-    def __init__(self, reader: pyarrow.RecordBatchReader) -> None:
+    def __init__(self, stmt, reader: pyarrow.RecordBatchReader) -> None:
+        self._stmt = stmt
         self._reader = reader
         self._current_batch = None
         self._next_row = 0
@@ -1118,7 +1123,9 @@ class _RowIterator(_Closeable):
         if self._current_batch is None or self._next_row >= 
len(self._current_batch):
             try:
                 while True:
-                    self._current_batch = self._reader.read_next_batch()
+                    self._current_batch = _blocking_call(
+                        self._reader.read_next_batch, (), {}, self._stmt.cancel
+                    )
                     if self._current_batch.num_rows > 0:
                         break
                 self._next_row = 0
diff --git a/python/adbc_driver_manager/pyproject.toml b/python/adbc_driver_manager/pyproject.toml
index 0a03fa3f..d2db1f10 100644
--- a/python/adbc_driver_manager/pyproject.toml
+++ b/python/adbc_driver_manager/pyproject.toml
@@ -23,6 +23,7 @@ license = {text = "Apache-2.0"}
 readme = "README.md"
 requires-python = ">=3.9"
 dynamic = ["version"]
+dependencies = ["typing-extensions"]
 
 [project.optional-dependencies]
 dbapi = ["pandas", "pyarrow>=8.0.0"]
diff --git a/python/adbc_driver_manager/setup.py b/python/adbc_driver_manager/setup.py
index bbec1a01..1b6f1026 100644
--- a/python/adbc_driver_manager/setup.py
+++ b/python/adbc_driver_manager/setup.py
@@ -76,6 +76,8 @@ build_type = os.environ.get("ADBC_BUILD_TYPE", "release")
 
 if sys.platform == "win32":
     extra_compile_args = ["/std:c++17", "/DADBC_EXPORTING"]
+    if build_type == "debug":
+        extra_compile_args.extend(["/DEBUG:FULL"])
 else:
     extra_compile_args = ["-std=c++17"]
     if build_type == "debug":
@@ -93,6 +95,7 @@ setup(
             
include_dirs=[str(source_root.joinpath("adbc_driver_manager").resolve())],
             language="c++",
             sources=[
+                "adbc_driver_manager/_blocking_impl.cc",
                 "adbc_driver_manager/_lib.pyx",
                 "adbc_driver_manager/adbc_driver_manager.cc",
             ],
diff --git a/python/adbc_driver_manager/tests/test_blocking.py b/python/adbc_driver_manager/tests/test_blocking.py
new file mode 100644
index 00000000..ea700773
--- /dev/null
+++ b/python/adbc_driver_manager/tests/test_blocking.py
@@ -0,0 +1,141 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Direct tests of the SIGINT handler behind ``_lib._blocking_call``.
+
+Higher-level testing of SIGINT during queries appears to be flaky in CI due to
+having to send the signal, so these tests exercise the handler itself instead.
+"""
+
+import os
+import signal
+import threading
+import time
+
+import pytest
+
+from adbc_driver_manager import _lib
+
+# https://github.com/apache/arrow-adbc/issues/1522
+# It works fine on the normal Windows builds, but not under the Conda builds,
+# where there is an unexplained/unreproducible crash, and so for now this is
+# disabled on Windows
+pytestmark = pytest.mark.skipif(os.name == "nt", reason="Disabled on Windows")
+
+
+def _send_sigint():
+    """Deliver a Ctrl-C (SIGINT) to this process."""
+    # Windows behavior is different
+    # https://stackoverflow.com/questions/35772001
+    signum = signal.CTRL_C_EVENT if os.name == "nt" else signal.SIGINT
+    os.kill(os.getpid(), signum)
+
+
+def _blocking(event):
+    """Send SIGINT to this process, then block until ``event`` is set."""
+    # The order matters: the signal must arrive while we are blocked, so that
+    # the cancel callback (which sets ``event``) is what releases us.
+    _send_sigint()
+    event.wait()
+
+
+def test_sigint_fires():
+    """SIGINT during a blocking call invokes the cancel callback."""
+    # The "blocking" function sends SIGINT to itself and then waits for the
+    # cancel callback to release it; the test hangs if cancel never runs.
+    released = threading.Event()
+    _lib._blocking_call(_blocking, (released,), {}, released.set)
+
+
+def test_handler_restored():
+    """After the blocking call returns, SIGINT raises KeyboardInterrupt again."""
+    released = threading.Event()
+    _lib._blocking_call(_blocking, (released,), {}, released.set)
+
+    with pytest.raises(KeyboardInterrupt):
+        # ``released`` is already set, so this effectively only sends SIGINT
+        _blocking(released)
+        # Needed on Windows so the handler runs before we exit the block (we
+        # won't sleep for the full time)
+        time.sleep(60)
+
+
+def test_args_return():
+    """Positional and keyword arguments are forwarded; the result is returned."""
+
+    def _echo(a, *, b):
+        return a, b
+
+    result = _lib._blocking_call(_echo, (1,), {"b": 2}, lambda: None)
+    assert result == (1, 2)
+
+
+def test_blocking_raise():
+    """An exception from the blocking function propagates to the caller."""
+
+    def _fail():
+        raise ValueError("expected error")
+
+    with pytest.raises(ValueError, match="expected error"):
+        _lib._blocking_call(_fail, (), {}, lambda: None)
+
+
+def test_cancel_raise():
+    """An exception from the cancel callback is raised by the blocking call."""
+    released = threading.Event()
+
+    def _cancel_and_fail():
+        released.set()
+        raise ValueError("expected error")
+
+    with pytest.raises(ValueError, match="expected error"):
+        _lib._blocking_call(_blocking, (released,), {}, _cancel_and_fail)
+
+
+def test_both_raise():
+    """If both raise, the blocking error wins and the cancel error is its cause."""
+    released = threading.Event()
+
+    def _block_then_fail(event):
+        _send_sigint()
+        event.wait()
+        raise ValueError("expected error 1")
+
+    def _cancel_and_fail():
+        released.set()
+        raise ValueError("expected error 2")
+
+    with pytest.raises(ValueError, match="expected error 1") as excinfo:
+        _lib._blocking_call(_block_then_fail, (released,), {}, _cancel_and_fail)
+    cause = excinfo.value.__cause__
+    assert cause is not None
+    with pytest.raises(ValueError, match="expected error 2"):
+        raise cause
+
+
+def test_nested():
+    """Nested blocking calls still leave the original SIGINT handler in place."""
+    # To be clear, don't ever do this.
+    released = threading.Event()
+
+    def _inner():
+        _lib._blocking_call(_blocking, (released,), {}, released.set)
+
+    _lib._blocking_call(_inner, (), {}, lambda: None)
+
+    # The original handler should be restored
+    with pytest.raises(KeyboardInterrupt):
+        _send_sigint()
+        # Needed on Windows so the handler runs before we exit the block (we
+        # won't sleep for the full time)
+        time.sleep(60)

Reply via email to