Author: Matti Picus <[email protected]>
Branch: indexing-by-array
Changeset: r62407:fc9348336eda
Date: 2013-03-18 14:04 -0700
http://bitbucket.org/pypy/pypy/changeset/fc9348336eda/
Log: broadcast only when needed, fix compress for nd array with no axis
arg (rguillebert, mattip)
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -79,12 +79,10 @@
raise OperationError(space.w_ValueError,
space.wrap("index out of range for array"))
size = loop.count_all_true(arr)
- print 'size',size
if len(arr.get_shape()) == 1:
res_shape = [size] + self.get_shape()[1:]
else:
res_shape = [size]
- print 'res_shape',res_shape
res = W_NDimArray.from_shape(res_shape, self.get_dtype())
return loop.getitem_filter(res, self, arr)
@@ -369,8 +367,11 @@
if not space.is_none(w_axis):
raise OperationError(space.w_NotImplementedError,
space.wrap("axis unsupported for compress"))
+ arr = self
+ else:
+ arr = self.descr_reshape(space, [space.wrap(-1)])
index = convert_to_array(space, w_obj)
- return self.getitem_filter(space, index)
+ return arr.getitem_filter(space, index)
def descr_flatten(self, space, w_order=None):
if self.is_scalar():
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -300,11 +300,12 @@
def getitem_filter(res, arr, index):
res_iter = res.create_iter()
- print 'index.size', index.get_size()
- print 'res.size', res.get_size()
- index_iter = index.create_iter(arr.get_shape(), backward_broadcast=True)
+ shapelen = len(arr.get_shape())
+ if shapelen > 1 and len(index.get_shape()) < 2:
+ index_iter = index.create_iter(arr.get_shape(), backward_broadcast=True)
+ else:
+ index_iter = index.create_iter()
arr_iter = arr.create_iter()
- shapelen = len(arr.get_shape())
arr_dtype = arr.get_dtype()
index_dtype = index.get_dtype()
# XXX length of shape of index as well?
@@ -313,7 +314,6 @@
index_dtype=index_dtype,
arr_dtype=arr_dtype,
)
- print 'res,arr,index', res_iter.offset, arr_iter.offset, index_iter.offset, index_iter.getitem_bool()
if index_iter.getitem_bool():
res_iter.setitem(arr_iter.getitem())
res_iter.next()
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -2155,26 +2155,16 @@
def test_compress(self):
from numpypy import arange, array
a = arange(10)
- print 0
assert (a.compress([True, False, True]) == [0, 2]).all()
- print 1
assert (a.compress([1, 0, 13]) == [0, 2]).all()
- print 2
assert (a.compress([1, 0, 13]) == [0, 2]).all()
- print '2a'
assert (a.compress([1, 0, 13.5]) == [0, 2]).all()
- print 3
assert (a.compress(array([1, 0, 13.5], dtype='>f4')) == [0, 2]).all()
- print 4
assert (a.compress(array([1, 0, 13.5], dtype='<f4')) == [0, 2]).all()
- print 5
assert (a.compress([1, -0-0j, 1.3+13.5j]) == [0, 2]).all()
- print 6
a = arange(10).reshape(2, 5)
assert (a.compress([True, False, True]) == [0, 2]).all()
- print 7
raises((IndexError, ValueError), "a.compress([1] * 100)")
- print 8
def test_item(self):
from numpypy import array
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit