https://github.com/python/cpython/commit/cf2532b39d099e004d1c07b2d0fcc46567b68e75
commit: cf2532b39d099e004d1c07b2d0fcc46567b68e75
branch: 3.12
author: Raymond Hettinger <[email protected]>
committer: rhettinger <[email protected]>
date: 2024-10-08T20:16:18Z
summary:
[3.12] Tee of tee was not producing n independent iterators (gh-123884)
(gh-125153)
files:
A Misc/NEWS.d/next/Library/2024-09-24-22-38-51.gh-issue-123884.iEPTK4.rst
M Doc/library/itertools.rst
M Lib/test/test_itertools.py
M Modules/itertoolsmodule.c
diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst
index 3fab46c3c0a5b4..047d805eda628a 100644
--- a/Doc/library/itertools.rst
+++ b/Doc/library/itertools.rst
@@ -676,24 +676,37 @@ loops that truncate the stream.
Roughly equivalent to::
def tee(iterable, n=2):
- iterator = iter(iterable)
- shared_link = [None, None]
- return tuple(_tee(iterator, shared_link) for _ in range(n))
-
- def _tee(iterator, link):
- try:
- while True:
- if link[1] is None:
- link[0] = next(iterator)
- link[1] = [None, None]
- value, link = link
- yield value
- except StopIteration:
- return
-
- Once a :func:`tee` has been created, the original *iterable* should not be
- used anywhere else; otherwise, the *iterable* could get advanced without
- the tee objects being informed.
+ if n < 0:
+ raise ValueError
+ if n == 0:
+ return ()
+ iterator = _tee(iterable)
+ result = [iterator]
+ for _ in range(n - 1):
+ result.append(_tee(iterator))
+ return tuple(result)
+
+ class _tee:
+
+ def __init__(self, iterable):
+ it = iter(iterable)
+ if isinstance(it, _tee):
+ self.iterator = it.iterator
+ self.link = it.link
+ else:
+ self.iterator = it
+ self.link = [None, None]
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ link = self.link
+ if link[1] is None:
+ link[0] = next(self.iterator)
+ link[1] = [None, None]
+ value, self.link = link
+ return value
``tee`` iterators are not threadsafe. A :exc:`RuntimeError` may be
raised when simultaneously using iterators returned by the same :func:`tee`
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
index 3d20e70fc1b63f..b6404f4366ca0e 100644
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -1612,10 +1612,11 @@ def test_tee(self):
self.assertEqual(len(result), n)
self.assertEqual([list(x) for x in result], [list('abc')]*n)
- # tee pass-through to copyable iterator
+ # tee objects are independent (see bug gh-123884)
a, b = tee('abc')
c, d = tee(a)
- self.assertTrue(a is c)
+ e, f = tee(c)
+ self.assertTrue(len({a, b, c, d, e, f}) == 6)
# test tee_new
t1, t2 = tee('abc')
@@ -2029,6 +2030,172 @@ def test_islice_recipe(self):
self.assertEqual(next(c), 3)
+ def test_tee_recipe(self):
+
+ # Begin tee() recipe ###########################################
+
+ def tee(iterable, n=2):
+ if n < 0:
+ raise ValueError
+ if n == 0:
+ return ()
+ iterator = _tee(iterable)
+ result = [iterator]
+ for _ in range(n - 1):
+ result.append(_tee(iterator))
+ return tuple(result)
+
+ class _tee:
+
+ def __init__(self, iterable):
+ it = iter(iterable)
+ if isinstance(it, _tee):
+ self.iterator = it.iterator
+ self.link = it.link
+ else:
+ self.iterator = it
+ self.link = [None, None]
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ link = self.link
+ if link[1] is None:
+ link[0] = next(self.iterator)
+ link[1] = [None, None]
+ value, self.link = link
+ return value
+
+ # End tee() recipe #############################################
+
+ n = 200
+
+ a, b = tee([]) # test empty iterator
+ self.assertEqual(list(a), [])
+ self.assertEqual(list(b), [])
+
+ a, b = tee(irange(n)) # test 100% interleaved
+ self.assertEqual(lzip(a,b), lzip(range(n), range(n)))
+
+ a, b = tee(irange(n)) # test 0% interleaved
+ self.assertEqual(list(a), list(range(n)))
+ self.assertEqual(list(b), list(range(n)))
+
+ a, b = tee(irange(n)) # test dealloc of leading iterator
+ for i in range(100):
+ self.assertEqual(next(a), i)
+ del a
+ self.assertEqual(list(b), list(range(n)))
+
+ a, b = tee(irange(n)) # test dealloc of trailing iterator
+ for i in range(100):
+ self.assertEqual(next(a), i)
+ del b
+ self.assertEqual(list(a), list(range(100, n)))
+
+ for j in range(5): # test randomly interleaved
+ order = [0]*n + [1]*n
+ random.shuffle(order)
+ lists = ([], [])
+ its = tee(irange(n))
+ for i in order:
+ value = next(its[i])
+ lists[i].append(value)
+ self.assertEqual(lists[0], list(range(n)))
+ self.assertEqual(lists[1], list(range(n)))
+
+ # test argument format checking
+ self.assertRaises(TypeError, tee)
+ self.assertRaises(TypeError, tee, 3)
+ self.assertRaises(TypeError, tee, [1,2], 'x')
+ self.assertRaises(TypeError, tee, [1,2], 3, 'x')
+
+ # tee object should be instantiable
+ a, b = tee('abc')
+ c = type(a)('def')
+ self.assertEqual(list(c), list('def'))
+
+ # test long-lagged and multi-way split
+ a, b, c = tee(range(2000), 3)
+ for i in range(100):
+ self.assertEqual(next(a), i)
+ self.assertEqual(list(b), list(range(2000)))
+ self.assertEqual([next(c), next(c)], list(range(2)))
+ self.assertEqual(list(a), list(range(100,2000)))
+ self.assertEqual(list(c), list(range(2,2000)))
+
+ # test invalid values of n
+ self.assertRaises(TypeError, tee, 'abc', 'invalid')
+ self.assertRaises(ValueError, tee, [], -1)
+
+ for n in range(5):
+ result = tee('abc', n)
+ self.assertEqual(type(result), tuple)
+ self.assertEqual(len(result), n)
+ self.assertEqual([list(x) for x in result], [list('abc')]*n)
+
+ # tee objects are independent (see bug gh-123884)
+ a, b = tee('abc')
+ c, d = tee(a)
+ e, f = tee(c)
+ self.assertTrue(len({a, b, c, d, e, f}) == 6)
+
+ # test tee_new
+ t1, t2 = tee('abc')
+ tnew = type(t1)
+ self.assertRaises(TypeError, tnew)
+ self.assertRaises(TypeError, tnew, 10)
+ t3 = tnew(t1)
+ self.assertTrue(list(t1) == list(t2) == list(t3) == list('abc'))
+
+ # test that tee objects are weak referencable
+ a, b = tee(range(10))
+ p = weakref.proxy(a)
+ self.assertEqual(getattr(p, '__class__'), type(b))
+ del a
+ gc.collect() # For PyPy or other GCs.
+ self.assertRaises(ReferenceError, getattr, p, '__class__')
+
+ ans = list('abc')
+ long_ans = list(range(10000))
+
+ # Tests not applicable to the tee() recipe
+ if False:
+ # check copy
+ a, b = tee('abc')
+ self.assertEqual(list(copy.copy(a)), ans)
+ self.assertEqual(list(copy.copy(b)), ans)
+ a, b = tee(list(range(10000)))
+ self.assertEqual(list(copy.copy(a)), long_ans)
+ self.assertEqual(list(copy.copy(b)), long_ans)
+
+ # check partially consumed copy
+ a, b = tee('abc')
+ take(2, a)
+ take(1, b)
+ self.assertEqual(list(copy.copy(a)), ans[2:])
+ self.assertEqual(list(copy.copy(b)), ans[1:])
+ self.assertEqual(list(a), ans[2:])
+ self.assertEqual(list(b), ans[1:])
+ a, b = tee(range(10000))
+ take(100, a)
+ take(60, b)
+ self.assertEqual(list(copy.copy(a)), long_ans[100:])
+ self.assertEqual(list(copy.copy(b)), long_ans[60:])
+ self.assertEqual(list(a), long_ans[100:])
+ self.assertEqual(list(b), long_ans[60:])
+
+ # Issue 13454: Crash when deleting backward iterator from tee()
+ forward, backward = tee(repeat(None, 2000)) # 20000000
+ try:
+ any(forward) # exhaust the iterator
+ del backward
+ except:
+ del forward, backward
+ raise
+
+
class TestGC(unittest.TestCase):
def makecycle(self, iterator, container):
diff --git a/Misc/NEWS.d/next/Library/2024-09-24-22-38-51.gh-issue-123884.iEPTK4.rst b/Misc/NEWS.d/next/Library/2024-09-24-22-38-51.gh-issue-123884.iEPTK4.rst
new file mode 100644
index 00000000000000..55f1d4b41125c3
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2024-09-24-22-38-51.gh-issue-123884.iEPTK4.rst
@@ -0,0 +1,4 @@
+Fixed bug in itertools.tee() handling of other tee inputs (a tee in a tee).
+The output now has the promised *n* independent new iterators. Formerly,
+the first iterator was identical (not independent) to the input iterator.
+This would sometimes give surprising results.
diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c
index d42f9dd0768658..e87c753113563f 100644
--- a/Modules/itertoolsmodule.c
+++ b/Modules/itertoolsmodule.c
@@ -1137,7 +1137,7 @@ itertools_tee_impl(PyObject *module, PyObject *iterable, Py_ssize_t n)
/*[clinic end generated code: output=1c64519cd859c2f0 input=c99a1472c425d66d]*/
{
Py_ssize_t i;
- PyObject *it, *copyable, *copyfunc, *result;
+ PyObject *it, *to, *result;
if (n < 0) {
PyErr_SetString(PyExc_ValueError, "n must be >= 0");
@@ -1154,41 +1154,24 @@ itertools_tee_impl(PyObject *module, PyObject *iterable, Py_ssize_t n)
return NULL;
}
- if (_PyObject_LookupAttr(it, &_Py_ID(__copy__), &copyfunc) < 0) {
- Py_DECREF(it);
+ (void)&_Py_ID(__copy__); // Retain a reference to __copy__
+ itertools_state *state = get_module_state(module);
+ to = tee_fromiterable(state, it);
+ Py_DECREF(it);
+ if (to == NULL) {
Py_DECREF(result);
return NULL;
}
- if (copyfunc != NULL) {
- copyable = it;
- }
- else {
- itertools_state *state = get_module_state(module);
- copyable = tee_fromiterable(state, it);
- Py_DECREF(it);
- if (copyable == NULL) {
- Py_DECREF(result);
- return NULL;
- }
- copyfunc = PyObject_GetAttr(copyable, &_Py_ID(__copy__));
- if (copyfunc == NULL) {
- Py_DECREF(copyable);
- Py_DECREF(result);
- return NULL;
- }
- }
- PyTuple_SET_ITEM(result, 0, copyable);
+ PyTuple_SET_ITEM(result, 0, to);
for (i = 1; i < n; i++) {
- copyable = _PyObject_CallNoArgs(copyfunc);
- if (copyable == NULL) {
- Py_DECREF(copyfunc);
+ to = tee_copy((teeobject *)to, NULL);
+ if (to == NULL) {
Py_DECREF(result);
return NULL;
}
- PyTuple_SET_ITEM(result, i, copyable);
+ PyTuple_SET_ITEM(result, i, to);
}
- Py_DECREF(copyfunc);
return result;
}
_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: [email protected]