Author: Sergey Matyunin <sbmatyu...@gmail.com>
Branch: numpy_broadcast_nd
Changeset: r84066:d52b849b3779
Date: 2016-04-24 13:54 +0200
http://bitbucket.org/pypy/pypy/changeset/d52b849b3779/

Log:    Implemented reset for numpy broadcast object.

diff --git a/pypy/module/micronumpy/broadcast.py b/pypy/module/micronumpy/broadcast.py
--- a/pypy/module/micronumpy/broadcast.py
+++ b/pypy/module/micronumpy/broadcast.py
@@ -75,6 +75,14 @@
             return res[0]
         return space.newtuple(res)
 
+    def descr_reset(self, space):
+        """Implement app-level broadcast.reset(): set index back to 0,
+        clear the done flag and reset every iterator in list_iter_state."""
+        self.index = 0
+        self.done = False
+        for it in self.list_iter_state:
+            it.reset()
+
 
 W_Broadcast.typedef = TypeDef("numpy.broadcast",
                               __new__=interp2app(descr_new_broadcast),
@@ -86,4 +94,5 @@
                               numiter=GetSetProperty(W_Broadcast.descr_get_numiter),
                               nd=GetSetProperty(W_Broadcast.descr_get_number_of_dimensions),
                               iters=GetSetProperty(W_Broadcast.descr_get_iters),
+                              reset=interp2app(W_Broadcast.descr_reset),
                               )
diff --git a/pypy/module/micronumpy/flatiter.py b/pypy/module/micronumpy/flatiter.py
--- a/pypy/module/micronumpy/flatiter.py
+++ b/pypy/module/micronumpy/flatiter.py
@@ -76,7 +76,7 @@
                                          base.get_order(), w_instance=base)
             return loop.flatiter_getitem(res, self.iter, state, step)
         finally:
-            self.iter.reset(self.state, mutate=True)
+            self.reset()
 
     def descr_setitem(self, space, w_idx, w_value):
         if not (space.isinstance_w(w_idx, space.w_int) or
@@ -96,11 +96,15 @@
             arr = convert_to_array(space, w_value)
             loop.flatiter_setitem(space, dtype, arr, self.iter, state, step, length)
         finally:
-            self.iter.reset(self.state, mutate=True)
+            self.reset()
 
     def descr___array_wrap__(self, space, obj, w_context=None):
         return obj
 
+    def reset(self):
+        """Rewind this iterator by resetting self.state in place."""
+        self.iter.reset(self.state, mutate=True)
+
 W_FlatIterator.typedef = TypeDef("numpy.flatiter",
     base = GetSetProperty(W_FlatIterator.descr_base),
     index = GetSetProperty(W_FlatIterator.descr_index),
diff --git a/pypy/module/micronumpy/test/test_broadcast.py b/pypy/module/micronumpy/test/test_broadcast.py
--- a/pypy/module/micronumpy/test/test_broadcast.py
+++ b/pypy/module/micronumpy/test/test_broadcast.py
@@ -123,3 +123,18 @@
         assert step_in_y == y[0, 0]  # == 3
         assert step_in_broadcast == (1, 3)
         assert step2_in_y == y[1, 0]  # == 4
+
+    def test_broadcast_reset(self):
+        import numpy as np
+        x = np.array([1, 2, 3])
+        y = np.array([[4], [5], [6]])
+
+        b = np.broadcast(x, y)
+        for _ in range(3):
+            b.next()
+        assert b.index == 3
+        b.reset()
+
+        assert b.index == 0
+        assert b.next() == (1, 4)
+        assert b.next() == (2, 4)
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to