Author: Maciej Fijalkowski <[email protected]>
Branch: numpy-indexing-by-arrays-2
Changeset: r51395:6fe5770303bb
Date: 2012-01-17 14:28 +0200
http://bitbucket.org/pypy/pypy/changeset/6fe5770303bb/
Log: fix setitem with bool index
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -537,16 +537,18 @@
return res
def setitem_filter(self, space, idx, val):
- arr = SliceArray(self.shape, self.dtype, self, val)
- shapelen = len(arr.shape)
+ size = self.count_all_true(idx)
+ arr = SliceArray([size], self.dtype, self, val)
sig = arr.find_sig()
+ shapelen = len(self.shape)
frame = sig.create_frame(arr)
idxi = idx.create_iter()
while not frame.done():
if idx.dtype.getitem_bool(idx.storage, idxi.offset):
sig.eval(frame, arr)
+ frame.next_from_second(1)
+ frame.next_first(shapelen)
idxi = idxi.next(shapelen)
- frame.next(shapelen)
def descr_getitem(self, space, w_idx):
if (isinstance(w_idx, BaseArray) and w_idx.shape == self.shape and
diff --git a/pypy/module/micronumpy/signature.py b/pypy/module/micronumpy/signature.py
--- a/pypy/module/micronumpy/signature.py
+++ b/pypy/module/micronumpy/signature.py
@@ -82,6 +82,16 @@
for i in range(len(self.iterators)):
self.iterators[i] = self.iterators[i].next(shapelen)
+ @unroll_safe
+ def next_from_second(self, shapelen):
+ """ Don't increase the first iterator
+ """
+ for i in range(1, len(self.iterators)):
+ self.iterators[i] = self.iterators[i].next(shapelen)
+
+ def next_first(self, shapelen):
+ self.iterators[0] = self.iterators[0].next(shapelen)
+
def get_final_iter(self):
final_iter = promote(self.final_iter)
if final_iter < 0:
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1327,7 +1327,7 @@
assert (a == [0, 1, 2, 3, 15, 15]).all()
a = arange(6).reshape(3, 2)
a[a & 1 == 1] = array([8, 9, 10])
- assert (a == [[0, 8], [3, 9], [5, 10]]).all()
+ assert (a == [[0, 8], [2, 9], [4, 10]]).all()
class AppTestSupport(BaseNumpyAppTest):
def setup_class(cls):
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit