Author: Maciej Fijalkowski <[email protected]>
Branch: numpy-fancy-indexing
Changeset: r57386:f8761cec0456
Date: 2012-09-19 15:18 +0200
http://bitbucket.org/pypy/pypy/changeset/f8761cec0456/
Log: pass the next test (test_int_array_index: indexing an array with an integer array)
diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py
b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -227,6 +227,7 @@
@jit.unroll_safe
def _lookup_by_unwrapped_index(self, space, lst):
item = self.start
+ assert len(lst) == len(self.shape)
for i, idx in enumerate(lst):
if idx < 0:
idx = self.shape[i] + idx
diff --git a/pypy/module/micronumpy/interp_numarray.py
b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -86,7 +86,8 @@
else:
arr = convert_to_array(space, w_index)
return arr.get_shape(), [arr]
- xxx
+ xxx # determine shape
+ return w_lst
def getitem_array_int(self, space, w_index):
iter_shape, indexes = self._prepare_array_index(space, w_index)
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -4,9 +4,10 @@
"""
from pypy.rlib.objectmodel import specialize
+from pypy.rlib.rstring import StringBuilder
+from pypy.rlib import jit
+from pypy.rpython.lltypesystem import lltype, rffi
from pypy.module.micronumpy.base import W_NDimArray
-from pypy.rlib.rstring import StringBuilder
-from pypy.rpython.lltypesystem import lltype, rffi
def call2(shape, func, name, calc_dtype, res_dtype, w_lhs, w_rhs, out):
if out is None:
@@ -245,26 +246,63 @@
iter.next()
return builder.build()
-def getitem_array_int(space, arr, res, iter_shape, indexes):
- assert len(indexes) == 1
- assert len(iter_shape) == 1
- res_iter = res.create_iter() # this shape is whatever shape res comes in
- index_iter = indexes[0].create_iter()
- while not index_iter.done():
- idx = space.int_w(index_iter.getitem())
- res_iter.setitem(arr.getitem(space, [idx]))
- index_iter.next()
- res_iter.next()
+class PureShapeIterator(object):
+    """Walk every coordinate of `shape` in C (row-major) order.
+
+    Besides the coordinate counter, keeps in `idx_w` one iterator per
+    entry of the `idx_w` constructor argument: entries that are
+    W_NDimArray get `create_iter(shape)`, all others stay None.  Those
+    iterators are advanced in lockstep with the counter by next().
+    """
+
+    def __init__(self, shape, idx_w):
+        self.shape = shape
+        self.shapelen = len(shape)
+        # current coordinate, one counter per dimension
+        self.indexes = [0] * len(shape)
+        # a shape containing a zero-length dimension has no elements, so
+        # iteration is finished before the first step; without this the
+        # caller's loop body would run once with an out-of-range coordinate
+        self._done = False
+        for size in shape:
+            if size == 0:
+                self._done = True
+        # NOTE: despite the w_ naming, entries are array iterators or None
+        self.idx_w = [None] * len(idx_w)
+        for i, w_idx in enumerate(idx_w):
+            if isinstance(w_idx, W_NDimArray):
+                self.idx_w[i] = w_idx.create_iter(shape)
+
+    def done(self):
+        return self._done
+
+    @jit.unroll_safe
+    def next(self):
+        # move the index-array iterators forward together with the counter
+        for w_idx in self.idx_w:
+            if w_idx is not None:
+                w_idx.next()
+        # odometer increment: bump the last dimension, carrying into earlier
+        # ones; carrying out of the first dimension means iteration is over
+        for i in range(self.shapelen - 1, -1, -1):
+            if self.indexes[i] < self.shape[i] - 1:
+                self.indexes[i] += 1
+                break
+            else:
+                self.indexes[i] = 0
+        else:
+            self._done = True
+
+    def get_index(self, space):
+        # the current coordinate as an app-level tuple of wrapped ints
+        return space.newtuple([space.wrap(i) for i in self.indexes])
+
+def getitem_array_int(space, arr, res, iter_shape, indexes_w):
+    """Fill `res` with the result of integer-array (fancy) indexing of `arr`.
+
+    Iterates over every coordinate of `iter_shape`.  For each one it builds
+    an index tuple with one entry per element of `indexes_w`:
+      - an index array contributes its current element (via its iterator);
+      - any other entry is passed through unchanged.
+    It then stores `arr[index_tuple]` into `res` at that coordinate.
+    Returns `res`.
+    """
+    it = PureShapeIterator(iter_shape, indexes_w)
+    while not it.done():
+        # the tuple has one slot per entry of indexes_w; iter_shape may have
+        # a different rank (e.g. a 2-d index array), so it must not size it
+        index_w = [None] * len(indexes_w)
+        for i, w_index in enumerate(indexes_w):
+            idx_iter = it.idx_w[i]
+            if idx_iter is not None:
+                index_w[i] = idx_iter.getitem()
+            else:
+                index_w[i] = w_index
+        res.descr_setitem(space, it.get_index(space),
+                          arr.descr_getitem(space, space.newtuple(index_w)))
+        it.next()
     return res
-def setitem_array_int(space, arr, iter_shape, indexes, val_arr):
- assert len(indexes) == 1
- assert len(iter_shape) == 1
- index_iter = indexes[0].create_iter()
- dtype = arr.get_dtype()
- val_iter = val_arr.create_iter(iter_shape)
- while not index_iter.done():
- idx = space.int_w(index_iter.getitem())
- arr.setitem(space, [idx], val_iter.getitem().convert_to(dtype))
- val_iter.next()
- index_iter.next()
+def setitem_array_int(space, arr, iter_shape, indexes_w, val_arr):
+    """Assign `val_arr` into `arr` through integer-array (fancy) indexing.
+
+    Iterates over every coordinate of `iter_shape`.  For each one it builds
+    an index tuple with one entry per element of `indexes_w`:
+      - an index array contributes its current element (via its iterator);
+      - any other entry is passed through unchanged.
+    It then stores `val_arr[coordinate]` into `arr[index_tuple]`.
+
+    NOTE(review): val_arr is indexed directly with the iteration
+    coordinate, so it presumably must already have shape iter_shape (no
+    broadcasting of e.g. a scalar value here) -- confirm the caller
+    guarantees this.
+    """
+    it = PureShapeIterator(iter_shape, indexes_w)
+    while not it.done():
+        # the tuple has one slot per entry of indexes_w; iter_shape may have
+        # a different rank (e.g. a 2-d index array), so it must not size it
+        index_w = [None] * len(indexes_w)
+        for i, w_index in enumerate(indexes_w):
+            idx_iter = it.idx_w[i]
+            if idx_iter is not None:
+                index_w[i] = idx_iter.getitem()
+            else:
+                index_w[i] = w_index
+        arr.descr_setitem(space, space.newtuple(index_w),
+                          val_arr.descr_getitem(space, it.get_index(space)))
+        it.next()
+
diff --git a/pypy/module/micronumpy/test/test_numarray.py
b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1992,7 +1992,12 @@
assert (a + a).item(1) == 4
raises(IndexError, "array(5).item(1)")
assert array([1]).item() == 1
-
+
+    def test_int_array_index(self):
+        """Indexing with a 1-d int array selects rows along the first axis."""
+        from _numpypy import array
+        src = array([[1, 2], [3, 4]])
+        # row 0 selected twice
+        picked = src[array([0, 0])]
+        assert (picked == [[1, 2], [1, 2]]).all()
class AppTestSupport(BaseNumpyAppTest):
def setup_class(cls):
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit