Author: Maciej Fijalkowski <[email protected]>
Branch: numpy-fancy-indexing
Changeset: r57386:f8761cec0456
Date: 2012-09-19 15:18 +0200
http://bitbucket.org/pypy/pypy/changeset/f8761cec0456/

Log:    pass the next test

diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -227,6 +227,7 @@
     @jit.unroll_safe
     def _lookup_by_unwrapped_index(self, space, lst):
         item = self.start
+        assert len(lst) == len(self.shape)
         for i, idx in enumerate(lst):
             if idx < 0:
                 idx = self.shape[i] + idx
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -86,7 +86,8 @@
         else:
             arr = convert_to_array(space, w_index)
             return arr.get_shape(), [arr]
-        xxx
+        xxx # determine shape
+        return w_lst
 
     def getitem_array_int(self, space, w_index):
         iter_shape, indexes = self._prepare_array_index(space, w_index)
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -4,9 +4,10 @@
 """
 
 from pypy.rlib.objectmodel import specialize
+from pypy.rlib.rstring import StringBuilder
+from pypy.rlib import jit
+from pypy.rpython.lltypesystem import lltype, rffi
 from pypy.module.micronumpy.base import W_NDimArray
-from pypy.rlib.rstring import StringBuilder
-from pypy.rpython.lltypesystem import lltype, rffi
 
 def call2(shape, func, name, calc_dtype, res_dtype, w_lhs, w_rhs, out):
     if out is None:
@@ -245,26 +246,63 @@
         iter.next()
     return builder.build()
 
-def getitem_array_int(space, arr, res, iter_shape, indexes):
-    assert len(indexes) == 1
-    assert len(iter_shape) == 1
-    res_iter = res.create_iter() # this shape is whatever shape res comes in
-    index_iter = indexes[0].create_iter()
-    while not index_iter.done():
-        idx = space.int_w(index_iter.getitem())
-        res_iter.setitem(arr.getitem(space, [idx]))
-        index_iter.next()
-        res_iter.next()
+class PureShapeIterator(object):
+    def __init__(self, shape, idx_w):
+        self.shape = shape
+        self.shapelen = len(shape)
+        self.indexes = [0] * len(shape)
+        self._done = False
+        self.idx_w = [None] * len(idx_w)
+        for i, w_idx in enumerate(idx_w):
+            if isinstance(w_idx, W_NDimArray):
+                self.idx_w[i] = w_idx.create_iter(shape)
+
+    def done(self):
+        return self._done
+
+    @jit.unroll_safe
+    def next(self):
+        for w_idx in self.idx_w:
+            if w_idx is not None:
+                w_idx.next()
+        for i in range(self.shapelen - 1, -1, -1):
+            if self.indexes[i] < self.shape[i] - 1:
+                self.indexes[i] += 1
+                break
+            else:
+                self.indexes[i] = 0
+        else:
+            self._done = True
+
+    def get_index(self, space):
+        return space.newtuple([space.wrap(i) for i in self.indexes])
+
+def getitem_array_int(space, arr, res, iter_shape, indexes_w):
+    iter = PureShapeIterator(iter_shape, indexes_w)
+    while not iter.done():
+        # prepare the index
+        index_w = [None] * len(iter_shape)
+        for i in range(len(iter_shape)):
+            if iter.idx_w[i] is not None:
+                index_w[i] = iter.idx_w[i].getitem()
+            else:
+                index_w[i] = indexes_w[i]
+        res.descr_setitem(space, iter.get_index(space),
+                          arr.descr_getitem(space, space.newtuple(index_w)))
+        iter.next()
     return res
 
-def setitem_array_int(space, arr, iter_shape, indexes, val_arr):
-    assert len(indexes) == 1
-    assert len(iter_shape) == 1
-    index_iter = indexes[0].create_iter()
-    dtype = arr.get_dtype()
-    val_iter = val_arr.create_iter(iter_shape)
-    while not index_iter.done():
-        idx = space.int_w(index_iter.getitem())
-        arr.setitem(space, [idx], val_iter.getitem().convert_to(dtype))
-        val_iter.next()
-        index_iter.next()
+def setitem_array_int(space, arr, iter_shape, indexes_w, val_arr):
+    iter = PureShapeIterator(iter_shape, indexes_w)
+    while not iter.done():
+        # prepare the index
+        index_w = [None] * len(iter_shape)
+        for i in range(len(iter_shape)):
+            if iter.idx_w[i] is not None:
+                index_w[i] = iter.idx_w[i].getitem()
+            else:
+                index_w[i] = indexes_w[i]
+        arr.descr_setitem(space, space.newtuple(index_w),
+                          val_arr.descr_getitem(space, iter.get_index(space)))
+        iter.next()
+
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1992,7 +1992,12 @@
         assert (a + a).item(1) == 4
         raises(IndexError, "array(5).item(1)")
         assert array([1]).item() == 1
- 
+
+    def test_int_array_index(self):
+        from _numpypy import array
+        a = array([[1, 2], [3, 4]])
+        b = a[array([0, 0])]
+        assert (b == [[1, 2], [1, 2]]).all()
 
 class AppTestSupport(BaseNumpyAppTest):
     def setup_class(cls):
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to