Author: Brian Kearns <[email protected]>
Branch:
Changeset: r68500:aca7d2177494
Date: 2013-12-19 18:37 -0500
http://bitbucket.org/pypy/pypy/changeset/aca7d2177494/
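Log: fix ndarray.take when called with an axis argument (compute the sizes of the dimensions before and after the axis with reduce(operator.mul, ...) instead of prod)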
diff --git a/pypy/module/micronumpy/interp_numarray.py
b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -1187,10 +1187,11 @@
if axis is None:
res = a.ravel()[indices]
else:
+ from operator import mul
if axis < 0: axis += len(a.shape)
s0, s1 = a.shape[:axis], a.shape[axis+1:]
- l0 = prod(s0) if s0 else 1
- l1 = prod(s1) if s1 else 1
+ l0 = reduce(mul, s0) if s0 else 1
+ l1 = reduce(mul, s1) if s1 else 1
res = a.reshape((l0, -1, l1))[:,indices,:].reshape(s0 + (-1,) + s1)
if out is not None:
out[:] = res
diff --git a/pypy/module/micronumpy/test/test_numarray.py
b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -2722,6 +2722,8 @@
raises(IndexError, "arange(3).take([15])")
a = arange(6).reshape(2, 3)
assert (a.take([1, 0, 3]) == [1, 0, 3]).all()
+ assert (a.take([1], axis=0) == [[3, 4, 5]]).all()
+ assert (a.take([1], axis=1) == [[1], [4]]).all()
assert ((a + a).take([3]) == [6]).all()
a = arange(12).reshape(2, 6)
assert (a[:,::2].take([3, 2, 1]) == [6, 4, 2]).all()
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit