Author: mattip
Branch: numppy-flatitter
Changeset: r51670:ff368a6d0ff7
Date: 2012-01-19 01:16 +0200
http://bitbucket.org/pypy/pypy/changeset/ff368a6d0ff7/

Log:    redo, add lots of tests, some still fail

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -1290,53 +1290,100 @@
     tolist = interp2app(BaseArray.descr_tolist),
 )
 
+#TODO: move all this to another file after fijal finishes reorganization
+def _to_coords(space, arr, w_item_or_slice):
+    '''Map a flat index or slice of arr to a list of wrapped coord tuples.
+
+    Flat positions are unravelled in row-major (C) order, which is how a
+    numpy flatiter always walks an array, whatever its memory layout.  An
+    int yields a one-element list, a slice one tuple per selected position
+    (possibly none).  Raises app-level IndexError for an out-of-range int
+    and NotImplementedError for any other kind of index.
+    '''
+    if space.isinstance_w(w_item_or_slice, space.w_int):
+        start = space.int_w(w_item_or_slice)
+        if start < 0:
+            start += arr.size
+        if start < 0 or start >= arr.size:
+            raise OperationError(space.w_IndexError,
+                                 space.wrap("invalid index"))
+        stop, step = start + 1, 1
+    elif space.isinstance_w(w_item_or_slice, space.w_slice):
+        start, stop, step, lngth = space.decode_index4(w_item_or_slice,
+                                                       arr.size)
+    else:
+        raise OperationError(space.w_NotImplementedError,
+                             space.wrap("cannot iterate over this index yet"))
+    ndim = len(arr.shape)
+    retval = []
+    for i in range(start, stop, step):
+        coords = [0] * ndim
+        ii = i
+        # peel off the last (fastest-varying) axis first
+        for s in range(ndim - 1, -1, -1):
+            coords[s] = ii % arr.shape[s]
+            ii //= arr.shape[s]
+        retval.append(space.newtuple([space.wrap(c) for c in coords]))
+    return retval
 
 class W_FlatIterator(ViewArray):

     @jit.unroll_safe
     def __init__(self, arr):
         arr = arr.get_concrete()
-        size = 1
-        for sh in arr.shape:
-            size *= sh
-        if arr.strides[-1] <= arr.strides[0]:
-            self.strides = [arr.strides[-1]]
-            self.backstrides = [arr.backstrides[-1]]
-        else:
-            XXX
-            # This will not work: getitem and setitem will
-            # fail. Need to be smarter: calculate the indices from the int
-            self.strides = [arr.strides[0]]
-            self.backstrides = [arr.backstrides[0]]
-        ViewArray.__init__(self, size, [size], arr.dtype, order=arr.order,
-                               parent=arr)
         self.shapelen = len(arr.shape)
         sig = arr.find_sig()
-        #self.iter = OneDimIterator(arr.start, self.strides[0],
-        #                           self.shape[0])
         self.iter = sig.create_frame(arr).get_final_iter()
-        self.start = arr.start
         self.base = arr
+        self.index = 0    # flat position of the next element next() yields
+        ViewArray.__init__(self, arr.size, [arr.size], arr.dtype, arr.order, arr)

     def descr_next(self, space):
         if self.iter.done():
             raise OperationError(space.w_StopIteration, space.w_None)
-        result = self.getitem(self.iter.offset)
+        result = self.base.getitem(self.iter.offset)
         self.iter = self.iter.next(self.shapelen)
+        self.index += 1
         return result

     def descr_iter(self):
         return self

     def descr_index(self, space):
-        return space.wrap(self.iter.offset)
+        return space.wrap(self.index)

+    def descr_coords(self, space):
+        return _to_coords(space, self.base, space.wrap(self.index))[0]
+
+    def descr_getitem(self, space, w_idx):
+        '''Int index: one element; slice: always a new 1-d array copy.'''
+        coords = _to_coords(space, self.base, w_idx)
+        if space.isinstance_w(w_idx, space.w_int):
+            return self.base.descr_getitem(space, coords[0])
+        w_result = W_NDimArray(len(coords), [len(coords)], self.base.dtype,
+                               self.base.order)
+        for i, w_coord in enumerate(coords):
+            w_result.setitem(i, self.base.descr_getitem(space, w_coord))
+        return w_result
+
+    def descr_setitem(self, space, w_idx, w_value):
+        '''Assign w_value at w_idx, repeating its values if too few.'''
+        coords = _to_coords(space, self.base, w_idx)
+        arr = convert_to_array(space, w_value)
+        if arr.size == 0:    # nothing to repeat; also avoids a modulo by 0
+            return
+        for i, w_coord in enumerate(coords):
+            # NOTE(review): getitem() takes an offset; assumes contiguous arr
+            self.base.descr_setitem(space, w_coord, arr.getitem(i % arr.size))
+
 W_FlatIterator.typedef = TypeDef(
     'flatiter',
     next = interp2app(W_FlatIterator.descr_next),
     __iter__ = interp2app(W_FlatIterator.descr_iter),
-    __getitem__ = interp2app(BaseArray.descr_getitem),
-    __setitem__ = interp2app(BaseArray.descr_setitem),
+    __getitem__ = interp2app(W_FlatIterator.descr_getitem),
+    __setitem__ = interp2app(W_FlatIterator.descr_setitem),
     index = GetSetProperty(W_FlatIterator.descr_index),
+    coords = GetSetProperty(W_FlatIterator.descr_coords),
+    # index/coords: flat position and tuple of the element next() yields
 )
 W_FlatIterator.acceptable_as_base_class = False
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1300,7 +1300,7 @@
         assert (a + a).flat[3] == 6
         assert a[::2].flat[3] == 6
         assert a.reshape(2,5).flat[3] == 3
-        b = a.flat
+        b = a.reshape(2,5).flat
         b.next()
         b.next()
         b.next()
@@ -1309,6 +1309,19 @@
         raises(IndexError, "b[11]")
         raises(IndexError, "b[-11]")
         assert b.index == 3
+        assert b.coords == (0,3)
+
+    def test_flatiter_setitem(self):
+        from _numpypy import arange, array
+        a = arange(12).reshape(3,4)
+        b = a.T.flat    # a.T has shape (4,3): flat k -> a[k % 3, k // 3]
+        b[6::2] = [-1, -2]    # flat 6, 8, 10; values repeat -1, -2, -1
+        assert (a == [[0, 1, -1, 3], [4, 5, 6, -1], [8, 9, -2, 11]]).all()
+        b[0:2] = [[[100]]]    # one value broadcast to flat 0 and 1
+        assert a[0,0] == 100 and a[1,0] == 100
+        b[array([10, 11])] = [-20, -40]    # fancy index: flat 10 and 11
+        assert a[1,3] == -20 and a[2,3] == -40
+
 
     def test_flatiter_view(self):
         from _numpypy import arange
@@ -1323,8 +1336,14 @@
 
     def test_flatiter_transpose(self):
         from _numpypy import arange
-        a = arange(10)
-        assert a.reshape(2,5).T.flat[3] == 6
+        a = arange(10).reshape(2,5).T
+        it = a.flat
+        assert (it[:5] == [0, 5, 1, 6, 2]).all()
+        for _ in range(3):
+            it.next()
+        # a has shape (5,2), so flat position 3 is a[1, 1]
+        assert it.index == 3
+        assert it.coords == (1,1)
 
     def test_slice_copy(self):
         from _numpypy import zeros
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to