Author: Maciej Fijalkowski <[email protected]>
Branch: numpy-refactor
Changeset: r57188:67acfa0c2475
Date: 2012-09-06 21:30 +0200
http://bitbucket.org/pypy/pypy/changeset/67acfa0c2475/

Log:    setiter

diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -35,6 +35,9 @@
     def done(self):
         return self.offset >= self.size
 
+    def reset(self):
+        self.offset %= self.size  # wrap offset into [0, size) so iteration can repeat; NOTE(review): size == 0 is not handled
+
 class OneDimViewIterator(ConcreteArrayIterator):
     def __init__(self, array):
         self.array = array
@@ -55,6 +58,9 @@
     def done(self):
         return self.index >= self.size
 
+    def reset(self):
+        self.offset %= self.size  # NOTE(review): done() tests self.index, which is not reset here -- confirm this actually rewinds
+
 class MultiDimViewIterator(ConcreteArrayIterator):
     def __init__(self, array, start, strides, backstrides, shape):
         self.indexes = [0] * len(shape)
@@ -65,6 +71,7 @@
         self._done = False
         self.strides = strides
         self.backstrides = backstrides
+        self.size = array.size
 
     @jit.unroll_safe
     def next(self):
@@ -100,6 +107,9 @@
     def done(self):
         return self._done
 
+    def reset(self):
+        self.offset %= self.size  # NOTE(review): done() reads self._done; neither _done nor self.indexes is reset here -- confirm
+
 class AxisIterator(base.BaseArrayIterator):
     def __init__(self, array, shape, dim):
         self.shape = shape
diff --git a/pypy/module/micronumpy/arrayimpl/scalar.py b/pypy/module/micronumpy/arrayimpl/scalar.py
--- a/pypy/module/micronumpy/arrayimpl/scalar.py
+++ b/pypy/module/micronumpy/arrayimpl/scalar.py
@@ -20,6 +20,9 @@
     def done(self):
         return False
 
+    def reset(self):
+        pass  # done() is always False for this iterator, so there is no position to rewind
+
 class Scalar(base.BaseArrayImplementation):
     def __init__(self, dtype, value=None):
         self.value = value
diff --git a/pypy/module/micronumpy/interp_flatiter.py b/pypy/module/micronumpy/interp_flatiter.py
--- a/pypy/module/micronumpy/interp_flatiter.py
+++ b/pypy/module/micronumpy/interp_flatiter.py
@@ -1,5 +1,5 @@
 
-from pypy.module.micronumpy.base import W_NDimArray
+from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
 from pypy.module.micronumpy import loop
 from pypy.module.micronumpy.strides import to_coords
 from pypy.interpreter.baseobjspace import Wrappable
@@ -43,8 +43,6 @@
         self.reset()
         base = self.base
         start, stop, step, length = space.decode_index4(w_idx, base.get_size())
-        # setslice would have been better, but flat[u:v] for arbitrary
-        # shapes of array a cannot be represented as a[x1:x2, y1:y2]
         base_iter = base.create_iter()
         base_iter.next_skip_x(start)
         if length == 1:
@@ -53,6 +51,16 @@
                                      base.get_order())
         return loop.flatiter_getitem(res, base_iter, step)
 
+    def descr_setitem(self, space, w_idx, w_value):  # flat[w_idx] = w_value; w_idx must be an int or a slice
+        if not (space.isinstance_w(w_idx, space.w_int) or
+            space.isinstance_w(w_idx, space.w_slice)):
+            raise OperationError(space.w_IndexError,
+                                 space.wrap('unsupported iterator index'))
+        base = self.base
+        start, stop, step, length = space.decode_index4(w_idx, base.get_size())  # stop is unused
+        arr = convert_to_array(space, w_value)
+        loop.flatiter_setitem(self.base, arr, start, step, length)
+
     def descr_iter(self):
         return self
 
@@ -63,6 +71,7 @@
     'flatiter',
     __iter__ = interp2app(W_FlatIterator.descr_iter),
     __getitem__ = interp2app(W_FlatIterator.descr_getitem),
+    __setitem__ = interp2app(W_FlatIterator.descr_setitem),
     __len__ = interp2app(W_FlatIterator.descr_len),
 
     next = interp2app(W_FlatIterator.descr_next),
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -194,3 +194,17 @@
         base_iter.next_skip_x(step)
         ri.next()
     return res
+
+def flatiter_setitem(arr, val, start, step, length):  # fill `length` slots of arr, `step` apart from flat index `start`, from val
+    dtype = arr.get_dtype()
+    arr_iter = arr.create_iter()
+    val_iter = val.create_iter()
+    arr_iter.next_skip_x(start)
+    while length > 0:
+        arr_iter.setitem(val_iter.getitem().convert_to(dtype))
+        # need to repeat i_nput values until all assignments are done
+        arr_iter.next_skip_x(step)
+        length -= 1
+        val_iter.next()
+        # a shorter val is cycled: reset() wraps val_iter back to its start once exhausted
+        val_iter.reset()
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1831,7 +1831,6 @@
         a = arange(12).reshape(3,4)
         b = a.T.flat
         b[6::2] = [-1, -2]
-        print a == [[0, 1, -1, 3], [4, 5, 6, -1], [8, 9, -2, 11]]
         assert (a == [[0, 1, -1, 3], [4, 5, 6, -1], [8, 9, -2, 11]]).all()
         b[0:2] = [[[100]]]
         assert(a[0,0] == 100)
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to