This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new d076c69e81 GH-38676: [Python] Fix potential deadlock when CSV reading
errors out (#38713)
d076c69e81 is described below
commit d076c69e81e5d331bae214a3cf9fabedb17752fa
Author: Antoine Pitrou <[email protected]>
AuthorDate: Wed Nov 15 16:15:42 2023 +0100
GH-38676: [Python] Fix potential deadlock when CSV reading errors out
(#38713)
### Rationale for this change
A deadlock can happen in a C++ destructor in the following case:
* the C++ destructor is called from Python, holding the GIL
* the C++ destructor waits for a threaded task to finish
* the threaded task has invoked some Python code which is waiting to
acquire the GIL
### What changes are included in this PR?
To reliably prevent such a deadlock, introduce `std::shared_ptr` and
`std::unique_ptr` wrappers that release the GIL when deallocating the embedded
pointer.
### Are these changes tested?
Yes.
### Are there any user-facing changes?
No.
* Closes: #38676
Authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
---
python/pyarrow/_csv.pyx | 5 ++-
python/pyarrow/_dataset.pxd | 8 ++---
python/pyarrow/_dataset.pyx | 4 +--
python/pyarrow/_parquet.pyx | 6 ++--
python/pyarrow/includes/libarrow_python.pxd | 8 +++++
python/pyarrow/ipc.pxi | 2 +-
python/pyarrow/lib.pxd | 4 +--
python/pyarrow/src/arrow/python/common.h | 55 ++++++++++++++++++++++++++---
python/pyarrow/tests/test_csv.py | 21 +++++++++++
9 files changed, 93 insertions(+), 20 deletions(-)
diff --git a/python/pyarrow/_csv.pyx b/python/pyarrow/_csv.pyx
index e532d8d8ab..508488c0c3 100644
--- a/python/pyarrow/_csv.pyx
+++ b/python/pyarrow/_csv.pyx
@@ -26,8 +26,7 @@ from collections.abc import Mapping
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
-from pyarrow.includes.libarrow_python cimport (MakeInvalidRowHandler,
- PyInvalidRowCallback)
+from pyarrow.includes.libarrow_python cimport *
from pyarrow.lib cimport (check_status, Field, MemoryPool, Schema,
RecordBatchReader, ensure_type,
maybe_unbox_memory_pool, get_input_stream,
@@ -1251,7 +1250,7 @@ def read_csv(input_file, read_options=None,
parse_options=None,
CCSVParseOptions c_parse_options
CCSVConvertOptions c_convert_options
CIOContext io_context
- shared_ptr[CCSVReader] reader
+ SharedPtrNoGIL[CCSVReader] reader
shared_ptr[CTable] table
_get_reader(input_file, read_options, &stream)
diff --git a/python/pyarrow/_dataset.pxd b/python/pyarrow/_dataset.pxd
index 210e555800..bee9fc1f09 100644
--- a/python/pyarrow/_dataset.pxd
+++ b/python/pyarrow/_dataset.pxd
@@ -31,7 +31,7 @@ cdef CFileSource _make_file_source(object file, FileSystem
filesystem=*)
cdef class DatasetFactory(_Weakrefable):
cdef:
- shared_ptr[CDatasetFactory] wrapped
+ SharedPtrNoGIL[CDatasetFactory] wrapped
CDatasetFactory* factory
cdef init(self, const shared_ptr[CDatasetFactory]& sp)
@@ -45,7 +45,7 @@ cdef class DatasetFactory(_Weakrefable):
cdef class Dataset(_Weakrefable):
cdef:
- shared_ptr[CDataset] wrapped
+ SharedPtrNoGIL[CDataset] wrapped
CDataset* dataset
public dict _scan_options
@@ -59,7 +59,7 @@ cdef class Dataset(_Weakrefable):
cdef class Scanner(_Weakrefable):
cdef:
- shared_ptr[CScanner] wrapped
+ SharedPtrNoGIL[CScanner] wrapped
CScanner* scanner
cdef void init(self, const shared_ptr[CScanner]& sp)
@@ -122,7 +122,7 @@ cdef class FileWriteOptions(_Weakrefable):
cdef class Fragment(_Weakrefable):
cdef:
- shared_ptr[CFragment] wrapped
+ SharedPtrNoGIL[CFragment] wrapped
CFragment* fragment
cdef void init(self, const shared_ptr[CFragment]& sp)
diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index 48ee676915..d7d69965d0 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -3227,7 +3227,7 @@ cdef class RecordBatchIterator(_Weakrefable):
object iterator_owner
# Iterator is a non-POD type and Cython uses offsetof, leading
# to a compiler warning unless wrapped like so
- shared_ptr[CRecordBatchIterator] iterator
+ SharedPtrNoGIL[CRecordBatchIterator] iterator
def __init__(self):
_forbid_instantiation(self.__class__, subclasses_instead=False)
@@ -3273,7 +3273,7 @@ cdef class TaggedRecordBatchIterator(_Weakrefable):
"""An iterator over a sequence of record batches with fragments."""
cdef:
object iterator_owner
- shared_ptr[CTaggedRecordBatchIterator] iterator
+ SharedPtrNoGIL[CTaggedRecordBatchIterator] iterator
def __init__(self):
_forbid_instantiation(self.__class__, subclasses_instead=False)
diff --git a/python/pyarrow/_parquet.pyx b/python/pyarrow/_parquet.pyx
index 48091367b2..089ed7c75c 100644
--- a/python/pyarrow/_parquet.pyx
+++ b/python/pyarrow/_parquet.pyx
@@ -24,6 +24,7 @@ import warnings
from cython.operator cimport dereference as deref
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
+from pyarrow.includes.libarrow_python cimport *
from pyarrow.lib cimport (_Weakrefable, Buffer, Schema,
check_status,
MemoryPool, maybe_unbox_memory_pool,
@@ -1165,7 +1166,7 @@ cdef class ParquetReader(_Weakrefable):
cdef:
object source
CMemoryPool* pool
- unique_ptr[FileReader] reader
+ UniquePtrNoGIL[FileReader] reader
FileMetaData _metadata
shared_ptr[CRandomAccessFile] rd_handle
@@ -1334,7 +1335,7 @@ cdef class ParquetReader(_Weakrefable):
vector[int] c_row_groups
vector[int] c_column_indices
shared_ptr[CRecordBatch] record_batch
- unique_ptr[CRecordBatchReader] recordbatchreader
+ UniquePtrNoGIL[CRecordBatchReader] recordbatchreader
self.set_batch_size(batch_size)
@@ -1366,7 +1367,6 @@ cdef class ParquetReader(_Weakrefable):
check_status(
recordbatchreader.get().ReadNext(&record_batch)
)
-
if record_batch.get() == NULL:
break
diff --git a/python/pyarrow/includes/libarrow_python.pxd
b/python/pyarrow/includes/libarrow_python.pxd
index 4d109fc660..b8a3041796 100644
--- a/python/pyarrow/includes/libarrow_python.pxd
+++ b/python/pyarrow/includes/libarrow_python.pxd
@@ -261,6 +261,14 @@ cdef extern from "arrow/python/common.h" namespace
"arrow::py":
void RestorePyError(const CStatus& status) except *
+cdef extern from "arrow/python/common.h" namespace "arrow::py" nogil:
+ cdef cppclass SharedPtrNoGIL[T](shared_ptr[T]):
+ # This looks like the only way to satisfy both Cython 2 and Cython 3
+ SharedPtrNoGIL& operator=(...)
+ cdef cppclass UniquePtrNoGIL[T, DELETER=*](unique_ptr[T, DELETER]):
+ UniquePtrNoGIL& operator=(...)
+
+
cdef extern from "arrow/python/inference.h" namespace "arrow::py":
c_bool IsPyBool(object o)
c_bool IsPyInt(object o)
diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi
index fcb9eb729e..5d20a4f8b7 100644
--- a/python/pyarrow/ipc.pxi
+++ b/python/pyarrow/ipc.pxi
@@ -977,7 +977,7 @@ cdef
_wrap_record_batch_with_metadata(CRecordBatchWithMetadata c):
cdef class _RecordBatchFileReader(_Weakrefable):
cdef:
- shared_ptr[CRecordBatchFileReader] reader
+ SharedPtrNoGIL[CRecordBatchFileReader] reader
shared_ptr[CRandomAccessFile] file
CIpcReadOptions options
diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd
index 63ebe6aea8..ae197eca1c 100644
--- a/python/pyarrow/lib.pxd
+++ b/python/pyarrow/lib.pxd
@@ -552,12 +552,12 @@ cdef class CompressedOutputStream(NativeFile):
cdef class _CRecordBatchWriter(_Weakrefable):
cdef:
- shared_ptr[CRecordBatchWriter] writer
+ SharedPtrNoGIL[CRecordBatchWriter] writer
cdef class RecordBatchReader(_Weakrefable):
cdef:
- shared_ptr[CRecordBatchReader] reader
+ SharedPtrNoGIL[CRecordBatchReader] reader
cdef class Codec(_Weakrefable):
diff --git a/python/pyarrow/src/arrow/python/common.h
b/python/pyarrow/src/arrow/python/common.h
index bc567ef78e..4a7886695e 100644
--- a/python/pyarrow/src/arrow/python/common.h
+++ b/python/pyarrow/src/arrow/python/common.h
@@ -19,6 +19,7 @@
#include <functional>
#include <memory>
+#include <optional>
#include <utility>
#include "arrow/buffer.h"
@@ -134,13 +135,15 @@ class ARROW_PYTHON_EXPORT PyAcquireGIL {
// A RAII-style helper that releases the GIL until the end of a lexical block
class ARROW_PYTHON_EXPORT PyReleaseGIL {
public:
- PyReleaseGIL() { saved_state_ = PyEval_SaveThread(); }
-
- ~PyReleaseGIL() { PyEval_RestoreThread(saved_state_); }
+ PyReleaseGIL() : ptr_(PyEval_SaveThread(), &unique_ptr_deleter) {}
private:
- PyThreadState* saved_state_;
- ARROW_DISALLOW_COPY_AND_ASSIGN(PyReleaseGIL);
+ static void unique_ptr_deleter(PyThreadState* state) {
+ if (state) {
+ PyEval_RestoreThread(state);
+ }
+ }
+ std::unique_ptr<PyThreadState, decltype(&unique_ptr_deleter)> ptr_;
};
// A helper to call safely into the Python interpreter from arbitrary C++ code.
@@ -238,6 +241,48 @@ class ARROW_PYTHON_EXPORT OwnedRefNoGIL : public OwnedRef {
}
};
+template <template <typename...> typename SmartPtr, typename... Ts>
+class SmartPtrNoGIL : public SmartPtr<Ts...> {
+ using Base = SmartPtr<Ts...>;
+
+ public:
+ template <typename... Args>
+ SmartPtrNoGIL(Args&&... args) : Base(std::forward<Args>(args)...) {}
+
+ ~SmartPtrNoGIL() { reset(); }
+
+ template <typename... Args>
+ void reset(Args&&... args) {
+ auto release_guard = optional_gil_release();
+ Base::reset(std::forward<Args>(args)...);
+ }
+
+ template <typename V>
+ SmartPtrNoGIL& operator=(V&& v) {
+ auto release_guard = optional_gil_release();
+ Base::operator=(std::forward<V>(v));
+ return *this;
+ }
+
+ private:
+ // Only release the GIL if we own an object *and* the Python runtime is
+ // valid *and* the GIL is held.
+ std::optional<PyReleaseGIL> optional_gil_release() const {
+ if (this->get() != nullptr && Py_IsInitialized() && PyGILState_Check()) {
+ return PyReleaseGIL();
+ }
+ return {};
+ }
+};
+
+/// \brief A std::shared_ptr<T, ...> subclass that releases the GIL when destroying T
+template <typename... Ts>
+using SharedPtrNoGIL = SmartPtrNoGIL<std::shared_ptr, Ts...>;
+
+/// \brief A std::unique_ptr<T, ...> subclass that releases the GIL when destroying T
+template <typename... Ts>
+using UniquePtrNoGIL = SmartPtrNoGIL<std::unique_ptr, Ts...>;
+
template <typename Fn>
struct BoundFunction;
diff --git a/python/pyarrow/tests/test_csv.py b/python/pyarrow/tests/test_csv.py
index afc5380b75..31f24187e3 100644
--- a/python/pyarrow/tests/test_csv.py
+++ b/python/pyarrow/tests/test_csv.py
@@ -1970,3 +1970,24 @@ def test_write_csv_decimal(tmpdir, type_factory):
out = read_csv(tmpdir / "out.csv")
assert out.column('col').cast(type) == table.column('col')
+
+
+def test_read_csv_gil_deadlock():
+ # GH-38676
+ # This test depends on several preconditions:
+ # - the CSV input is a Python file object
+ # - reading the CSV file produces an error
+ data = b"a,b,c"
+
+ class MyBytesIO(io.BytesIO):
+ def read(self, *args):
+ time.sleep(0.001)
+ return super().read(*args)
+
+ def readinto(self, *args):
+ time.sleep(0.001)
+ return super().readinto(*args)
+
+ for i in range(20):
+ with pytest.raises(pa.ArrowInvalid):
+ read_csv(MyBytesIO(data))