Author: Ronan Lamy <[email protected]>
Branch: indexing
Changeset: r78533:723d0dacd7da
Date: 2015-07-12 21:20 +0200
http://bitbucket.org/pypy/pypy/changeset/723d0dacd7da/

Log:    Handle field access for record dtypes as a special case in
        getitem/setitem

diff --git a/pypy/module/micronumpy/arrayops.py b/pypy/module/micronumpy/arrayops.py
--- a/pypy/module/micronumpy/arrayops.py
+++ b/pypy/module/micronumpy/arrayops.py
@@ -162,8 +162,10 @@
         shape = [arr.get_shape()[0] * repeats]
         w_res = W_NDimArray.from_shape(space, shape, arr.get_dtype(), w_instance=arr)
         for i in range(repeats):
-            Chunks([Chunk(i, shape[0] - repeats + i, repeats,
-                 orig_size)]).apply(space, w_res).implementation.setslice(space, arr)
+            chunks = Chunks([Chunk(i, shape[0] - repeats + i, repeats,
+                 orig_size)])
+            view = chunks.apply(space, w_res)
+            view.implementation.setslice(space, arr)
     else:
         axis = space.int_w(w_axis)
         shape = arr.get_shape()[:]
@@ -174,7 +176,8 @@
         for i in range(repeats):
             chunks[axis] = Chunk(i, shape[axis] - repeats + i, repeats,
                                  orig_size)
-            Chunks(chunks).apply(space, w_res).implementation.setslice(space, arr)
+            view = Chunks(chunks).apply(space, w_res)
+            view.implementation.setslice(space, arr)
     return w_res
 
 
diff --git a/pypy/module/micronumpy/concrete.py b/pypy/module/micronumpy/concrete.py
--- a/pypy/module/micronumpy/concrete.py
+++ b/pypy/module/micronumpy/concrete.py
@@ -218,16 +218,10 @@
     @jit.unroll_safe
     def _prepare_slice_args(self, space, w_idx):
         if space.isinstance_w(w_idx, space.w_str):
-            idx = space.str_w(w_idx)
-            dtype = self.dtype
-            if not dtype.is_record():
-                raise oefmt(space.w_IndexError, "only integers, slices (`:`), "
-                    "ellipsis (`...`), numpy.newaxis (`None`) and integer or "
-                    "boolean arrays are valid indices")
-            elif idx not in dtype.fields:
-                raise oefmt(space.w_ValueError, "field named %s not found", idx)
-            return RecordChunk(idx)
-        elif (space.isinstance_w(w_idx, space.w_int) or
+            raise oefmt(space.w_IndexError, "only integers, slices (`:`), "
+                "ellipsis (`...`), numpy.newaxis (`None`) and integer or "
+                "boolean arrays are valid indices")
+        if (space.isinstance_w(w_idx, space.w_int) or
                 space.isinstance_w(w_idx, space.w_slice)):
             if len(self.get_shape()) == 0:
                 raise oefmt(space.w_ValueError, "cannot slice a 0-d array")
diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py
--- a/pypy/module/micronumpy/ndarray.py
+++ b/pypy/module/micronumpy/ndarray.py
@@ -18,8 +18,9 @@
 from pypy.module.micronumpy.converters import multi_axis_converter, \
     order_converter, shape_converter, searchside_converter
 from pypy.module.micronumpy.flagsobj import W_FlagsObject
-from pypy.module.micronumpy.strides import get_shape_from_iterable, \
-    shape_agreement, shape_agreement_multiple, is_c_contiguous, is_f_contiguous
+from pypy.module.micronumpy.strides import (
+    get_shape_from_iterable, shape_agreement, shape_agreement_multiple,
+    is_c_contiguous, is_f_contiguous, RecordChunk)
 from pypy.module.micronumpy.casting import can_cast_array
 
 
@@ -203,6 +204,10 @@
                                prefix)
 
     def descr_getitem(self, space, w_idx):
+        if self.get_dtype().is_record():
+            if space.isinstance_w(w_idx, space.w_str):
+                idx = space.str_w(w_idx)
+                return self.getfield(space, idx)
         if space.is_w(w_idx, space.w_Ellipsis):
             return self
         elif isinstance(w_idx, W_NDimArray) and w_idx.get_dtype().is_bool():
@@ -229,6 +234,13 @@
         self.implementation.setitem_index(space, index_list, w_value)
 
     def descr_setitem(self, space, w_idx, w_value):
+        if self.get_dtype().is_record():
+            if space.isinstance_w(w_idx, space.w_str):
+                idx = space.str_w(w_idx)
+                view = self.getfield(space, idx)
+                w_value = convert_to_array(space, w_value)
+                view.implementation.setslice(space, w_value)
+                return
         if space.is_w(w_idx, space.w_Ellipsis):
             self.implementation.setslice(space, convert_to_array(space, w_value))
             return
@@ -241,6 +253,13 @@
         except ArrayArgumentException:
             self.setitem_array_int(space, w_idx, w_value)
 
+    def getfield(self, space, field):
+        dtype = self.get_dtype()
+        if field not in dtype.fields:
+            raise oefmt(space.w_ValueError, "field named %s not found", field)
+        chunks = RecordChunk(field)
+        return chunks.apply(space, self)
+
     def descr_delitem(self, space, w_idx):
         raise OperationError(space.w_ValueError, space.wrap(
             "cannot delete array elements"))
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to