Author: mattip <[email protected]>
Branch: 
Changeset: r82688:403a0931e0bc
Date: 2016-03-04 11:58 +0200
http://bitbucket.org/pypy/pypy/changeset/403a0931e0bc/

Log:    merge ndarray-setitem-filtered, which fixes issue #1674, issue #1717

diff --git a/pypy/module/micronumpy/concrete.py b/pypy/module/micronumpy/concrete.py
--- a/pypy/module/micronumpy/concrete.py
+++ b/pypy/module/micronumpy/concrete.py
@@ -298,7 +298,14 @@
         except IndexError:
             # not a single result
             chunks = self._prepare_slice_args(space, w_index)
-            return new_view(space, orig_arr, chunks)
+            copy = False
+            if isinstance(chunks[0], BooleanChunk):
+                # numpy compatibility
+                copy = True
+            w_ret = new_view(space, orig_arr, chunks)
+            if copy:
+                w_ret = w_ret.descr_copy(space, space.wrap(w_ret.get_order()))
+            return w_ret
 
     def descr_setitem(self, space, orig_arr, w_index, w_value):
         try:
diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py
--- a/pypy/module/micronumpy/ndarray.py
+++ b/pypy/module/micronumpy/ndarray.py
@@ -22,7 +22,8 @@
 from pypy.module.micronumpy.flagsobj import W_FlagsObject
 from pypy.module.micronumpy.strides import (
     get_shape_from_iterable, shape_agreement, shape_agreement_multiple,
-    is_c_contiguous, is_f_contiguous, calc_strides, new_view)
+    is_c_contiguous, is_f_contiguous, calc_strides, new_view, BooleanChunk,
+    SliceChunk)
 from pypy.module.micronumpy.casting import can_cast_array
 from pypy.module.micronumpy.descriptor import get_dtype_cache
 
@@ -204,7 +205,13 @@
         if iter_shape is None:
             # w_index is a list of slices, return a view
             chunks = self.implementation._prepare_slice_args(space, w_index)
-            return new_view(space, self, chunks)
+            copy = False
+            if isinstance(chunks[0], BooleanChunk):
+                copy = True
+            w_ret = new_view(space, self, chunks)
+            if copy:
+                w_ret = w_ret.descr_copy(space, space.wrap(w_ret.get_order()))
+            return w_ret
         shape = res_shape + self.get_shape()[len(indexes):]
         w_res = W_NDimArray.from_shape(space, shape, self.get_dtype(),
                                        self.get_order(), w_instance=self)
@@ -220,8 +227,24 @@
         if iter_shape is None:
             # w_index is a list of slices
             chunks = self.implementation._prepare_slice_args(space, w_index)
-            view = new_view(space, self, chunks)
-            view.implementation.setslice(space, val_arr)
+            dim = -1
+            view = self
+            for i, c in enumerate(chunks):
+                if isinstance(c, BooleanChunk):
+                    dim = i
+                    idx = c.w_idx
+                    chunks.pop(i)
+                    chunks.insert(0, SliceChunk(space.newslice(space.wrap(0), 
+                                 space.w_None, space.w_None)))
+                    break
+            if dim > 0:
+                view = self.implementation.swapaxes(space, self, 0, dim) 
+            if dim >= 0:
+                view = new_view(space, self, chunks)
+                view.setitem_filter(space, idx, val_arr)
+            else:
+                view = new_view(space, self, chunks)
+                view.implementation.setslice(space, val_arr)
             return
         if support.product(iter_shape) == 0:
             return
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -97,22 +97,19 @@
         # filter by axis dim
         filtr = chunks[dim]
         assert isinstance(filtr, BooleanChunk) 
+        # XXX this creates a new array, and fails in setitem
         w_arr = w_arr.getitem_filter(space, filtr.w_idx, axis=dim)
         arr = w_arr.implementation
         chunks[dim] = SliceChunk(space.newslice(space.wrap(0), 
-                                 space.wrap(-1), space.w_None))
+                                 space.w_None, space.w_None))
         r = calculate_slice_strides(space, arr.shape, arr.start,
                  arr.get_strides(), arr.get_backstrides(), chunks)
     else:
         r = calculate_slice_strides(space, arr.shape, arr.start,
                      arr.get_strides(), arr.get_backstrides(), chunks)
     shape, start, strides, backstrides = r
-    w_ret = W_NDimArray.new_slice(space, start, strides[:], backstrides[:],
+    return W_NDimArray.new_slice(space, start, strides[:], backstrides[:],
                                  shape[:], arr, w_arr)
-    if dim == 0:
-        # Do not return a view
-        return w_ret.descr_copy(space, space.wrap(w_ret.get_order()))
-    return w_ret
 
 @jit.unroll_safe
 def _extend_shape(old_shape, chunks):
diff --git a/pypy/module/micronumpy/test/test_ndarray.py b/pypy/module/micronumpy/test/test_ndarray.py
--- a/pypy/module/micronumpy/test/test_ndarray.py
+++ b/pypy/module/micronumpy/test/test_ndarray.py
@@ -2550,8 +2550,10 @@
         assert b.base is None
         b = a[:, np.array([True, False, True])]
         assert b.base is not None
+        a[np.array([True, False]), 0] = 100
         b = a[np.array([True, False]), 0]
-        assert (b ==[0]).all()
+        assert b.shape == (1,)
+        assert (b ==[100]).all()
 
     def test_scalar_indexing(self):
         import numpy as np
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to