Author: Maciej Fijalkowski <[email protected]>
Branch: numpy-indexing-by-arrays-2
Changeset: r51395:6fe5770303bb
Date: 2012-01-17 14:28 +0200
http://bitbucket.org/pypy/pypy/changeset/6fe5770303bb/

Log:    fix setitem with a boolean index array

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -537,16 +537,18 @@
         return res
 
     def setitem_filter(self, space, idx, val):
-        arr = SliceArray(self.shape, self.dtype, self, val)
-        shapelen = len(arr.shape)
+        size = self.count_all_true(idx)
+        arr = SliceArray([size], self.dtype, self, val)
         sig = arr.find_sig()
+        shapelen = len(self.shape)
         frame = sig.create_frame(arr)
         idxi = idx.create_iter()
         while not frame.done():
             if idx.dtype.getitem_bool(idx.storage, idxi.offset):
                 sig.eval(frame, arr)
+                frame.next_from_second(1)
+            frame.next_first(shapelen)
             idxi = idxi.next(shapelen)
-            frame.next(shapelen)
 
     def descr_getitem(self, space, w_idx):
         if (isinstance(w_idx, BaseArray) and w_idx.shape == self.shape and
diff --git a/pypy/module/micronumpy/signature.py b/pypy/module/micronumpy/signature.py
--- a/pypy/module/micronumpy/signature.py
+++ b/pypy/module/micronumpy/signature.py
@@ -82,6 +82,16 @@
         for i in range(len(self.iterators)):
             self.iterators[i] = self.iterators[i].next(shapelen)
 
+    @unroll_safe
+    def next_from_second(self, shapelen):
+        """ Don't increase the first iterator
+        """
+        for i in range(1, len(self.iterators)):
+            self.iterators[i] = self.iterators[i].next(shapelen)
+
+    def next_first(self, shapelen):
+        self.iterators[0] = self.iterators[0].next(shapelen)
+
     def get_final_iter(self):
         final_iter = promote(self.final_iter)
         if final_iter < 0:
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1327,7 +1327,7 @@
         assert (a == [0, 1, 2, 3, 15, 15]).all()
         a = arange(6).reshape(3, 2)
         a[a & 1 == 1] = array([8, 9, 10])
-        assert (a == [[0, 8], [3, 9], [5, 10]]).all()
+        assert (a == [[0, 8], [2, 9], [4, 10]]).all()
 
 class AppTestSupport(BaseNumpyAppTest):
     def setup_class(cls):
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to