https://github.com/python/cpython/commit/1b6bef8086e46e189cb9c4e46ac2945bc7848ed6
commit: 1b6bef8086e46e189cb9c4e46ac2945bc7848ed6
branch: main
author: Tomasz Pytel <tompy...@gmail.com>
committer: kumaraditya303 <kumaradi...@python.org>
date: 2025-02-19T15:42:45+05:30
summary:

gh-129107: make `bytearray` iterator thread safe (#130096)

Co-authored-by: Kumar Aditya <kumaradi...@python.org>

files:
A Misc/NEWS.d/next/Core_and_Builtins/2025-02-13-20-42-53.gh-issue-129107._olg-L.rst
M Lib/test/test_bytes.py
M Objects/bytearrayobject.c

diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py
index 9ec5f4525c5d32..907988c27eba0a 100644
--- a/Lib/test/test_bytes.py
+++ b/Lib/test/test_bytes.py
@@ -2455,9 +2455,6 @@ def check(funcs, a=None, *args):
             with threading_helper.start_threads(threads):
                 pass
 
-            for thread in threads:
-                threading_helper.join_thread(thread)
-
         # hard errors
 
         check([clear] + [reduce] * 10)
@@ -2519,6 +2516,44 @@ def check(funcs, a=None, *args):
         check([clear] + [upper] * 10, bytearray(b'a' * 0x400000))
         check([clear] + [zfill] * 10, bytearray(b'1' * 0x200000))
 
+    @unittest.skipUnless(support.Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled')
+    @threading_helper.reap_threads
+    @threading_helper.requires_working_threading()
+    def test_free_threading_bytearrayiter(self):
+        # Non-deterministic but good chance to fail if bytearrayiter is not free-threading safe.
+        # We are fishing for a "Assertion failed: object has negative ref count" and tsan races.
+
+        def iter_next(b, it):
+            b.wait()
+            list(it)
+
+        def iter_reduce(b, it):
+            b.wait()
+            it.__reduce__()
+
+        def iter_setstate(b, it):
+            b.wait()
+            it.__setstate__(0)
+
+        def check(funcs, it):
+            barrier = threading.Barrier(len(funcs))
+            threads = []
+
+            for func in funcs:
+                thread = threading.Thread(target=func, args=(barrier, it))
+
+                threads.append(thread)
+
+            with threading_helper.start_threads(threads):
+                pass
+
+        for _ in range(10):
+            ba = bytearray(b'0' * 0x4000)  # this is a load-bearing variable, do not remove
+
+            check([iter_next] * 10, iter(ba))
+            check([iter_next] + [iter_reduce] * 10, iter(ba))  # for tsan
+            check([iter_next] + [iter_setstate] * 10, iter(ba))  # for tsan
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/Misc/NEWS.d/next/Core_and_Builtins/2025-02-13-20-42-53.gh-issue-129107._olg-L.rst b/Misc/NEWS.d/next/Core_and_Builtins/2025-02-13-20-42-53.gh-issue-129107._olg-L.rst
new file mode 100644
index 00000000000000..7ae9cd31e24ff5
--- /dev/null
+++ b/Misc/NEWS.d/next/Core_and_Builtins/2025-02-13-20-42-53.gh-issue-129107._olg-L.rst
@@ -0,0 +1 @@
+Make :class:`bytearray` iterator safe under :term:`free threading`.
diff --git a/Objects/bytearrayobject.c b/Objects/bytearrayobject.c
index 30cc05a1280dd1..f2cfd4aed3979f 100644
--- a/Objects/bytearrayobject.c
+++ b/Objects/bytearrayobject.c
@@ -2856,22 +2856,34 @@ static PyObject *
 bytearrayiter_next(PyObject *self)
 {
     bytesiterobject *it = _bytesiterobject_CAST(self);
-    PyByteArrayObject *seq;
+    int val;
 
     assert(it != NULL);
-    seq = it->it_seq;
-    if (seq == NULL)
+    Py_ssize_t index = FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index);
+    if (index < 0) {
         return NULL;
+    }
+    PyByteArrayObject *seq = it->it_seq;
     assert(PyByteArray_Check(seq));
 
-    if (it->it_index < PyByteArray_GET_SIZE(seq)) {
-        return _PyLong_FromUnsignedChar(
-            (unsigned char)PyByteArray_AS_STRING(seq)[it->it_index++]);
+    Py_BEGIN_CRITICAL_SECTION(seq);
+    if (index < Py_SIZE(seq)) {
+        val = (unsigned char)PyByteArray_AS_STRING(seq)[index];
+    }
+    else {
+        val = -1;
     }
+    Py_END_CRITICAL_SECTION();
 
-    it->it_seq = NULL;
-    Py_DECREF(seq);
-    return NULL;
+    if (val == -1) {
+        FT_ATOMIC_STORE_SSIZE_RELAXED(it->it_index, -1);
+#ifndef Py_GIL_DISABLED
+        Py_CLEAR(it->it_seq);
+#endif
+        return NULL;
+    }
+    FT_ATOMIC_STORE_SSIZE_RELAXED(it->it_index, index + 1);
+    return _PyLong_FromUnsignedChar((unsigned char)val);
 }
 
 static PyObject *
@@ -2879,8 +2891,9 @@ bytearrayiter_length_hint(PyObject *self, PyObject *Py_UNUSED(ignored))
 {
     bytesiterobject *it = _bytesiterobject_CAST(self);
     Py_ssize_t len = 0;
-    if (it->it_seq) {
-        len = PyByteArray_GET_SIZE(it->it_seq) - it->it_index;
+    Py_ssize_t index = FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index);
+    if (index >= 0) {
+        len = PyByteArray_GET_SIZE(it->it_seq) - index;
         if (len < 0) {
             len = 0;
         }
@@ -2900,27 +2913,33 @@ bytearrayiter_reduce(PyObject *self, PyObject *Py_UNUSED(ignored))
      * call must be before access of iterator pointers.
      * see issue #101765 */
     bytesiterobject *it = _bytesiterobject_CAST(self);
-    if (it->it_seq != NULL) {
-        return Py_BuildValue("N(O)n", iter, it->it_seq, it->it_index);
-    } else {
-        return Py_BuildValue("N(())", iter);
+    Py_ssize_t index = FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index);
+    if (index >= 0) {
+        return Py_BuildValue("N(O)n", iter, it->it_seq, index);
     }
+    return Py_BuildValue("N(())", iter);
 }
 
 static PyObject *
 bytearrayiter_setstate(PyObject *self, PyObject *state)
 {
     Py_ssize_t index = PyLong_AsSsize_t(state);
-    if (index == -1 && PyErr_Occurred())
+    if (index == -1 && PyErr_Occurred()) {
         return NULL;
+    }
 
     bytesiterobject *it = _bytesiterobject_CAST(self);
-    if (it->it_seq != NULL) {
-        if (index < 0)
-            index = 0;
-        else if (index > PyByteArray_GET_SIZE(it->it_seq))
-            index = PyByteArray_GET_SIZE(it->it_seq); /* iterator exhausted */
-        it->it_index = index;
+    if (FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index) >= 0) {
+        if (index < -1) {
+            index = -1;
+        }
+        else {
+            Py_ssize_t size = PyByteArray_GET_SIZE(it->it_seq);
+            if (index > size) {
+                index = size; /* iterator at end */
+            }
+        }
+        FT_ATOMIC_STORE_SSIZE_RELAXED(it->it_index, index);
     }
     Py_RETURN_NONE;
 }
@@ -2982,7 +3001,7 @@ bytearray_iter(PyObject *seq)
     it = PyObject_GC_New(bytesiterobject, &PyByteArrayIter_Type);
     if (it == NULL)
         return NULL;
-    it->it_index = 0;
+    it->it_index = 0;  // -1 indicates exhausted
     it->it_seq = (PyByteArrayObject *)Py_NewRef(seq);
     _PyObject_GC_TRACK(it);
     return (PyObject *)it;

_______________________________________________
Python-checkins mailing list -- python-checkins@python.org
To unsubscribe send an email to python-checkins-le...@python.org
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: arch...@mail-archive.com

Reply via email to