Author: Matti Picus <[email protected]>
Branch: boolean-indexing-cleanup
Changeset: r66973:c206de17847f
Date: 2013-09-16 21:07 +0300
http://bitbucket.org/pypy/pypy/changeset/c206de17847f/
Log: 'fix' boolean assignment by allowing creation of simple iterators if
the total size (product of the shape) matches (numpy compatibility)
diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py
b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -316,7 +316,8 @@
self.storage = storage
def create_iter(self, shape=None, backward_broadcast=False):
- if shape is None or shape == self.get_shape():
+ if shape is None or \
+ support.product(shape) == support.product(self.get_shape()):
return iter.ConcreteArrayIterator(self)
r = calculate_broadcast_strides(self.get_strides(),
self.get_backstrides(),
diff --git a/pypy/module/micronumpy/interp_numarray.py
b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -97,13 +97,13 @@
space.wrap("index out of range for array"))
idx_iter = idx.create_iter(self.get_shape())
size = loop.count_all_true_iter(idx_iter, self.get_shape(),
idx.get_dtype())
- if len(val.get_shape()) > 0 and val.get_shape()[0] > 1 and size >
val.get_shape()[0]:
+ if size != val.get_size() and val.get_size() > 1:
raise OperationError(space.w_ValueError, space.wrap("NumPy boolean
array indexing assignment "
"cannot assign
%d input values to "
- "the %d output
values where the mask is true" % (val.get_shape()[0],size)))
+ "the %d output
values where the mask is true" % (val.get_size(), size)))
if val.get_shape() == [0]:
val.implementation.dtype = self.implementation.dtype
- loop.setitem_filter(self, idx, val)
+ loop.setitem_filter(self, idx, val, size)
def _prepare_array_index(self, space, w_index):
if isinstance(w_index, W_NDimArray):
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -372,10 +372,10 @@
'index_dtype'],
reds = 'auto')
-def setitem_filter(arr, index, value):
+def setitem_filter(arr, index, value, size):
arr_iter = arr.create_iter()
index_iter = index.create_iter(arr.get_shape())
- value_iter = value.create_iter(arr.get_shape())
+ value_iter = value.create_iter([size])
shapelen = len(arr.get_shape())
index_dtype = index.get_dtype()
arr_dtype = arr.get_dtype()
diff --git a/pypy/module/micronumpy/test/test_numarray.py
b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -2355,11 +2355,11 @@
def test_array_indexing_bool_specialcases(self):
from numpypy import arange, array
a = arange(6)
- try:
- a[a < 3] = [1, 2]
- assert False, "Should not work"
- except ValueError:
- pass
+ exc = raises(ValueError,'a[a < 3] = [1, 2]')
+ assert exc.value[0].find('cannot assign') >= 0
+ b = arange(4).reshape(2, 2) + 10
+ a[a < 4] = b
+ assert (a == [10, 11, 12, 13, 4, 5]).all()
a = arange(6)
a[a > 3] = array([15])
assert (a == [0, 1, 2, 3, 15, 15]).all()
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit