Author: Maciej Fijalkowski <[email protected]>
Branch: numpy-indexing-by-arrays-2
Changeset: r51373:bec89443f75e
Date: 2012-01-17 10:55 +0200
http://bitbucket.org/pypy/pypy/changeset/bec89443f75e/

Log:    indexing by bool arrays of the same shape, step 1

diff --git a/pypy/module/micronumpy/interp_dtype.py b/pypy/module/micronumpy/interp_dtype.py
--- a/pypy/module/micronumpy/interp_dtype.py
+++ b/pypy/module/micronumpy/interp_dtype.py
@@ -46,6 +46,10 @@
     def getitem(self, storage, i):
         return self.itemtype.read(storage, self.itemtype.get_element_size(), i, 0)
 
+    def getitem_bool(self, storage, i):
+        isize = self.itemtype.get_element_size()
+        return self.itemtype.read_raw(storage, isize, i, 0, bool)
+
     def setitem(self, storage, i, box):
         self.itemtype.store(storage, self.itemtype.get_element_size(), i, 0, box)
 
diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -62,11 +62,18 @@
         self.size = size
 
     def next(self, shapelen):
+        return self._next(1)
+
+    def _next(self, ofs):
         arr = instantiate(ArrayIterator)
         arr.size = self.size
-        arr.offset = self.offset + 1
+        arr.offset = self.offset + ofs
         return arr
 
+    def next_no_increase(self, shapelen):
+        # a hack to make JIT believe this is always virtual
+        return self._next(0)
+
     def done(self):
         return self.offset >= self.size
 
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -483,7 +483,41 @@
         return [Chunk(*space.decode_index4(w_item, self.shape[i])) for i, w_item in
                 enumerate(space.fixedview(w_idx))]
 
+    def count_all_true(self, arr):
+        sig = arr.find_sig()
+        frame = sig.create_frame(self)
+        shapelen = len(arr.shape)
+        s = 0
+        while not frame.done():
+            iter = frame.get_final_iter()
+            s += arr.dtype.getitem_bool(arr.storage, iter.offset)
+            frame.next(shapelen)
+        return s
+
+    def getitem_filter(self, space, arr):
+        concr = arr.get_concrete()
+        size = self.count_all_true(concr)
+        res = W_NDimArray(size, [size], self.find_dtype())
+        ri = ArrayIterator(size)
+        shapelen = len(self.shape)
+        argi = ArrayIterator(concr.size)
+        sig = self.find_sig()
+        frame = sig.create_frame(self)
+        while not frame.done():
+            if concr.dtype.getitem_bool(concr.storage, argi.offset):
+                v = sig.eval(frame, self)
+                res.setitem(ri.offset, v)
+                ri = ri.next(1)
+            else:
+                ri = ri.next_no_increase(1)
+            argi = argi.next(shapelen)
+            frame.next(shapelen)
+        return res
+
     def descr_getitem(self, space, w_idx):
+        if (isinstance(w_idx, BaseArray) and w_idx.shape == self.shape and
+            w_idx.find_dtype().is_bool_type()):
+            return self.getitem_filter(space, w_idx)
         if self._single_item_result(space, w_idx):
             concrete = self.get_concrete()
             item = concrete._index_of_single_item(space, w_idx)
@@ -716,8 +750,7 @@
                                          frame=frame,
                                          ri=ri,
                                          self=self, result=result)
-            result.dtype.setitem(result.storage, ri.offset,
-                                 sig.eval(frame, self))
+            result.setitem(ri.offset, sig.eval(frame, self))
             frame.next(shapelen)
             ri = ri.next(shapelen)
         return result
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1312,6 +1312,13 @@
         raises(IndexError,'arange(3)[array([-15])]')
         assert arange(3)[array(1)] == 1
 
+    def test_array_indexing_bool(self):
+        from _numpypy import arange
+        a = arange(10)
+        assert (a[a > 3] == [4, 5, 6, 7, 8, 9]).all()
+        a = arange(10).reshape(5, 2)
+        assert (a[a > 3] == [4, 5, 6, 7, 8, 9]).all()
+
 class AppTestSupport(BaseNumpyAppTest):
     def setup_class(cls):
         import struct
diff --git a/pypy/module/micronumpy/types.py b/pypy/module/micronumpy/types.py
--- a/pypy/module/micronumpy/types.py
+++ b/pypy/module/micronumpy/types.py
@@ -94,6 +94,11 @@
             width, storage, i, offset
         ))
 
+    @specialize.arg(5)
+    def read_raw(self, storage, width, i, offset, tp):
+        return libffi.array_getitem(clibffi.cast_type_to_ffitype(self.T),
+                                    width, storage, i, offset)
+
     def store(self, storage, width, i, offset, box):
         value = self.unbox(box)
         libffi.array_setitem(clibffi.cast_type_to_ffitype(self.T),
@@ -436,4 +441,4 @@
 class Float64(BaseType, Float):
     T = rffi.DOUBLE
     BoxType = interp_boxes.W_Float64Box
-    format_code = "d"
\ No newline at end of file
+    format_code = "d"
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to