Author: Matti Picus <[email protected]>
Branch: 
Changeset: r67397:eff24d19da2b
Date: 2013-10-15 19:18 +0300
http://bitbucket.org/pypy/pypy/changeset/eff24d19da2b/

Log:    test, implement ndarray.flat = val

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -399,6 +399,10 @@
     def descr_repeat(self, space, repeats, w_axis=None):
         return repeat(space, self, repeats, w_axis)
 
+    def descr_set_flatiter(self, space, w_obj):
+        arr = convert_to_array(space, w_obj)
+        loop.flatiter_setitem(space, self, arr, 0, 1, self.get_size())
+
     def descr_get_flatiter(self, space):
         return space.wrap(W_FlatIterator(self))
 
@@ -1130,7 +1134,8 @@
     repeat = interp2app(W_NDimArray.descr_repeat),
     swapaxes = interp2app(W_NDimArray.descr_swapaxes),
     nonzero = interp2app(W_NDimArray.descr_nonzero),
-    flat = GetSetProperty(W_NDimArray.descr_get_flatiter),
+    flat = GetSetProperty(W_NDimArray.descr_get_flatiter,
+                          W_NDimArray.descr_set_flatiter),
     item = interp2app(W_NDimArray.descr_item),
     real = GetSetProperty(W_NDimArray.descr_get_real,
                           W_NDimArray.descr_set_real),
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -2273,6 +2273,24 @@
         assert len(arange(10)[:2].flat) == 2
         assert len((arange(2) + arange(2)).flat) == 2
 
+    def test_flatiter_setter(self):
+        from numpypy import arange, array
+        a = arange(24).reshape(2, 3, 4)
+        a.flat = [4, 5]
+        assert (a.flatten() == [4, 5]*12).all()
+        a.flat = [[4, 5, 6, 7, 8], [4, 5, 6, 7, 8]]
+        assert (a.flatten() == ([4, 5, 6, 7, 8]*5)[:24]).all()
+        exc = raises(ValueError, 'a.flat = [[4, 5, 6, 7, 8], [4, 5, 6]]')
+        assert str(exc.value).find("sequence") > 0
+        b = a[::-1, :, ::-1]
+        b.flat = range(24)
+        assert (a.flatten() == [15, 14 ,13, 12, 19, 18, 17, 16, 23, 22,
+                                21, 20, 3, 2, 1, 0, 7, 6, 5, 4,
+                                11, 10, 9, 8]).all()
+        c = array(['abc'] * 10).reshape(2, 5)
+        c.flat = ['defgh', 'ijklmnop']
+        assert (c.flatten() == ['def', 'ijk']*5).all()
+
     def test_slice_copy(self):
         from numpypy import zeros
         a = zeros((10, 10))
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to