This is an automated email from the ASF dual-hosted git repository.

raulcd pushed a commit to branch maint-15.0.x
in repository https://gitbox.apache.org/repos/asf/arrow.git

commit dec703aa28347d5e8f6d568fb80bfb4a4c6aa79a
Author: Tom Jarosz <[email protected]>
AuthorDate: Tue Jan 9 05:25:21 2024 -0800

    GH-39313: [Python] Fix race condition in _pandas_api#_check_import (#39314)
    
    ### Rationale for this change
    
    See:
    ```
        cdef inline bint _have_pandas_internal(self):
            if not self._tried_importing_pandas:
                self._check_import(raise_=False)
            return self._have_pandas
    ```
    
    The method `_check_import`:
    1) sets `_tried_importing_pandas` to true
    2) does some things which take time...
    3) sets `_have_pandas` to true (if we indeed do have pandas)
    
    Suppose thread 1 calls `_have_pandas_internal`. If thread 2 calls `_have_pandas_internal` while thread 1 is still at step 2, `_have_pandas_internal` may incorrectly return False for thread 2: thread 1 has already set `_tried_importing_pandas` to True, but has not yet set `_have_pandas` to True (it will do so shortly). Thread 1 itself will get True from `_have_pandas_internal`.
    
    After my fix, `_have_pandas_internal` no longer returns an incorrect value in the scenario described above. Instead, thread 2 may make a redundant, but (I believe) harmless, call to `_check_import`.
    
    ### What changes are included in this PR?
    
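    Changes the ordering of "trying to import pandas" and "recording that the pandas import has been tried", and guards the first import attempt with a lock so that concurrent callers wait for it to finish.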
    
    ### Are these changes tested?
    Yes, see the committed test.
    
    ### Are there any user-facing changes?
    
    This PR resolves a user-facing race condition: https://github.com/apache/arrow/issues/39313
    * Closes: #39313
    
    Lead-authored-by: Thomas Jarosz <[email protected]>
    Co-authored-by: Antoine Pitrou <[email protected]>
    Signed-off-by: Antoine Pitrou <[email protected]>
---
 python/pyarrow/pandas-shim.pxi      | 22 ++++++++++-------
 python/pyarrow/tests/arrow_39313.py | 47 +++++++++++++++++++++++++++++++++++++
 python/pyarrow/tests/test_pandas.py |  6 +++++
 3 files changed, 67 insertions(+), 8 deletions(-)

diff --git a/python/pyarrow/pandas-shim.pxi b/python/pyarrow/pandas-shim.pxi
index 273575b779..0409e133ad 100644
--- a/python/pyarrow/pandas-shim.pxi
+++ b/python/pyarrow/pandas-shim.pxi
@@ -18,6 +18,7 @@
 # pandas lazy-loading API shim that reduces API call and import overhead
 
 import warnings
+from threading import Lock
 
 
 cdef class _PandasAPIShim(object):
@@ -34,12 +35,13 @@ cdef class _PandasAPIShim(object):
         object _pd, _types_api, _compat_module
         object _data_frame, _index, _series, _categorical_type
         object _datetimetz_type, _extension_array, _extension_dtype
-        object _array_like_types, _is_extension_array_dtype
+        object _array_like_types, _is_extension_array_dtype, _lock
         bint has_sparse
         bint _pd024
         bint _is_v1, _is_ge_v21
 
     def __init__(self):
+        self._lock = Lock()
         self._tried_importing_pandas = False
         self._have_pandas = 0
 
@@ -96,13 +98,17 @@ cdef class _PandasAPIShim(object):
         self.has_sparse = False
 
     cdef inline _check_import(self, bint raise_=True):
-        if self._tried_importing_pandas:
-            if not self._have_pandas and raise_:
-                self._import_pandas(raise_)
-            return
-
-        self._tried_importing_pandas = True
-        self._import_pandas(raise_)
+        if not self._tried_importing_pandas:
+            with self._lock:
+                if not self._tried_importing_pandas:
+                    try:
+                        self._import_pandas(raise_)
+                    finally:
+                        self._tried_importing_pandas = True
+                    return
+
+        if not self._have_pandas and raise_:
+            self._import_pandas(raise_)
 
     def series(self, *args, **kwargs):
         self._check_import()
diff --git a/python/pyarrow/tests/arrow_39313.py b/python/pyarrow/tests/arrow_39313.py
new file mode 100644
index 0000000000..1e769f49d9
--- /dev/null
+++ b/python/pyarrow/tests/arrow_39313.py
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# This file is called from a test in test_pandas.py.
+
+from threading import Thread
+
+import pandas as pd
+from pyarrow.pandas_compat import _pandas_api
+
+if __name__ == "__main__":
+    wait = True
+    num_threads = 10
+    df = pd.DataFrame()
+    results = []
+
+    def rc():
+        while wait:
+            pass
+        results.append(_pandas_api.is_data_frame(df))
+
+    threads = [Thread(target=rc) for _ in range(num_threads)]
+
+    for t in threads:
+        t.start()
+
+    wait = False
+
+    for t in threads:
+        t.join()
+
+    assert len(results) == num_threads
+    assert all(results), "`is_data_frame` returned False when given a DataFrame"
diff --git a/python/pyarrow/tests/test_pandas.py b/python/pyarrow/tests/test_pandas.py
index 3353bebce7..d15ee82d5d 100644
--- a/python/pyarrow/tests/test_pandas.py
+++ b/python/pyarrow/tests/test_pandas.py
@@ -34,6 +34,7 @@ import pytest
 from pyarrow.pandas_compat import get_logical_type, _pandas_api
 from pyarrow.tests.util import invoke_script, random_ascii, rands
 import pyarrow.tests.strategies as past
+import pyarrow.tests.util as test_util
 from pyarrow.vendored.version import Version
 
 import pyarrow as pa
@@ -5008,3 +5009,8 @@ def test_nested_chunking_valid():
     schema = pa.schema([("maps", map_type)])
     roundtrip(pd.DataFrame({"maps": [map_of_los, map_of_los, map_of_los]}),
               schema=schema)
+
+
+def test_is_data_frame_race_condition():
+    # See https://github.com/apache/arrow/issues/39313
+    test_util.invoke_script('arrow_39313.py')

Reply via email to