Author: Brian Kearns <[email protected]>
Branch:
Changeset: r68526:016e3a6acf4b
Date: 2013-12-21 00:35 -0500
http://bitbucket.org/pypy/pypy/changeset/016e3a6acf4b/
Log: fix take with multidimensional indices argument
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -1185,7 +1185,9 @@
def take(a, indices, axis, out, mode):
assert mode == 'raise'
if axis is None:
- res = a.ravel()[indices]
+ from numpy import array
+ indices = array(indices)
+ res = a.ravel()[indices.ravel()].reshape(indices.shape)
else:
from operator import mul
if axis < 0: axis += len(a.shape)
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -2726,7 +2726,10 @@
assert (arange(10).take([1, 2, 1, 1]) == [1, 2, 1, 1]).all()
raises(IndexError, "arange(3).take([15])")
a = arange(6).reshape(2, 3)
+ assert a.take(3) == 3
+ assert a.take(3).shape == ()
assert (a.take([1, 0, 3]) == [1, 0, 3]).all()
+ assert (a.take([[1, 0], [2, 3]]) == [[1, 0], [2, 3]]).all()
assert (a.take([1], axis=0) == [[3, 4, 5]]).all()
assert (a.take([1], axis=1) == [[1], [4]]).all()
assert ((a + a).take([3]) == [6]).all()
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit