Author: Omer Katz <[email protected]>
Branch: py3.3
Changeset: r80858:5a174924685f
Date: 2015-11-23 12:56 +0200
http://bitbucket.org/pypy/pypy/changeset/5a174924685f/

Log:    Added a __reduce__ method to the accumulate iterator.

diff --git a/pypy/module/itertools/interp_itertools.py b/pypy/module/itertools/interp_itertools.py
--- a/pypy/module/itertools/interp_itertools.py
+++ b/pypy/module/itertools/interp_itertools.py
@@ -1217,6 +1217,23 @@
             self.w_total = self.space.call_function(self.w_func, self.w_total, w_value)
         return self.w_total
 
+    def reduce_w(self):
+        # __reduce__: recreate via accumulate(iterator, func) and hand the
+        # running total over as the state (None while no total exists yet);
+        # setstate_w is the inverse.
+        space = self.space
+        w_total = space.w_None if self.w_total is None else self.w_total
+        w_func = space.w_None if self.w_func is None else self.w_func
+        return space.newtuple([space.gettypefor(W_Accumulate),
+                               space.newtuple([self.w_iterable, w_func]),
+                               w_total])
+
+    def setstate_w(self, w_state):
+        # __setstate__: undo reduce_w's encoding, mapping None back to the
+        # internal "no total yet" marker (self.w_total is None).
+        space = self.space
+        self.w_total = None if space.is_w(w_state, space.w_None) else w_state
+
 def W_Accumulate__new__(space, w_subtype, w_iterable, w_func=None):
     r = space.allocate_instance(W_Accumulate, w_subtype)
     r.__init__(space, space.iter(w_iterable), w_func)
@@ -1226,8 +1243,9 @@
     __new__  = interp2app(W_Accumulate__new__),
     __iter__ = interp2app(W_Accumulate.iter_w),
     __next__ = interp2app(W_Accumulate.next_w),
+    __reduce__ = interp2app(W_Accumulate.reduce_w),
+    __setstate__ = interp2app(W_Accumulate.setstate_w),
     __doc__  = """\
 "accumulate(iterable) --> accumulate object
 
 Return series of accumulated sums.""")
-
diff --git a/pypy/module/itertools/test/test_itertools.py b/pypy/module/itertools/test/test_itertools.py
--- a/pypy/module/itertools/test/test_itertools.py
+++ b/pypy/module/itertools/test/test_itertools.py
@@ -2,7 +2,7 @@
 import pytest
 
 
-class AppTestItertools: 
+class AppTestItertools:
     spaceconfig = dict(usemodules=['itertools'])
 
     def test_count(self):
@@ -298,11 +298,11 @@
 
     def test_chain(self):
         import itertools
-        
+
         it = itertools.chain()
         raises(StopIteration, next, it)
         raises(StopIteration, next, it)
-        
+
         it = itertools.chain([1, 2, 3])
         for x in [1, 2, 3]:
             assert next(it) == x
@@ -322,7 +322,7 @@
 
         it = itertools.cycle([])
         raises(StopIteration, next, it)
-        
+
         it = itertools.cycle([1, 2, 3])
         for x in [1, 2, 3, 1, 2, 3, 1, 2, 3]:
             assert next(it) == x
@@ -378,7 +378,7 @@
 
     def test_tee_wrongargs(self):
         import itertools
-        
+
         raises(TypeError, itertools.tee, 0)
         raises(ValueError, itertools.tee, [], -1)
         raises(TypeError, itertools.tee, [], None)
@@ -416,7 +416,7 @@
 
     def test_groupby(self):
         import itertools
-        
+
         it = itertools.groupby([])
         raises(StopIteration, next, it)
 
@@ -493,7 +493,7 @@
             assert next(g) is x
             raises(StopIteration, next, g)
         raises(StopIteration, next, it)
-        
+
         # Grouping is based on key equality
         class AlwaysEqual(object):
             def __eq__(self, other):
@@ -516,7 +516,7 @@
 
     def test_iterables(self):
         import itertools
-    
+
         iterables = [
             itertools.chain(),
             itertools.count(),
@@ -531,7 +531,7 @@
             itertools.tee([])[0],
             itertools.tee([])[1],
             ]
-    
+
         for it in iterables:
             assert hasattr(it, '__iter__')
             assert iter(it) is it
@@ -540,7 +540,7 @@
 
     def test_docstrings(self):
         import itertools
-        
+
         assert itertools.__doc__
         methods = [
             itertools.chain,
@@ -981,7 +981,7 @@
         # kw arg
         assert list(accumulate(iterable=range(10))) == expected
         # multiple types
-        for typ in int, complex, Decimal, Fraction:                 
+        for typ in int, complex, Decimal, Fraction:
             assert list(accumulate(map(typ, range(10)))) == list(map(typ, expected))
         assert list(accumulate('abc')) == ['a', 'ab', 'abc']   # works with non-numeric
         assert list(accumulate([])) == []                  # empty iterable
@@ -998,3 +998,14 @@
         raises(TypeError, list, accumulate(s, chr))        # unary-operation
         raises(TypeError, list, accumulate(s, lambda x,y,z: None))  # ternary
 
+        it = iter([10, 50, 150])
+        a = accumulate(it)
+        assert a.__reduce__() == (accumulate, (it, None), None)
+        next(a)
+        next(a)
+        assert a.__reduce__() == (accumulate, (it, None), 60)
+
+        # a user-supplied func is passed through unchanged
+        f = lambda x, y: x * y
+        it = iter([2, 3])
+        assert accumulate(it, f).__reduce__() == (accumulate, (it, f), None)
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to