Author: mattip <[email protected]>
Branch: 
Changeset: r74891:b75a06bcc48d
Date: 2014-12-11 18:36 +0200
http://bitbucket.org/pypy/pypy/changeset/b75a06bcc48d/

Log:    test, fix view of a slice

diff --git a/pypy/module/micronumpy/flagsobj.py b/pypy/module/micronumpy/flagsobj.py
--- a/pypy/module/micronumpy/flagsobj.py
+++ b/pypy/module/micronumpy/flagsobj.py
@@ -5,46 +5,28 @@
 from pypy.interpreter.gateway import interp2app
 from pypy.interpreter.typedef import TypeDef, GetSetProperty
 from pypy.module.micronumpy import constants as NPY
+from pypy.module.micronumpy.strides import is_c_contiguous, is_f_contiguous
 
 
 def enable_flags(arr, flags):
     arr.flags |= flags
 
 
 def clear_flags(arr, flags):
     arr.flags &= ~flags
 
 
-@jit.unroll_safe
 def _update_contiguous_flags(arr):
-    shape = arr.shape
-    strides = arr.strides
-
-    is_c_contig = True
-    sd = arr.dtype.elsize
-    for i in range(len(shape) - 1, -1, -1):
-        dim = shape[i]
-        if strides[i] != sd:
-            is_c_contig = False
-            break
-        if dim == 0:
-            break
-        sd *= dim
-    if is_c_contig:
+    """Sync arr's C/F-contiguous flag bits with its current layout."""
+    if is_c_contiguous(arr):
         enable_flags(arr, NPY.ARRAY_C_CONTIGUOUS)
     else:
         clear_flags(arr, NPY.ARRAY_C_CONTIGUOUS)
 
-    sd = arr.dtype.elsize
-    for i in range(len(shape)):
-        dim = shape[i]
-        if strides[i] != sd:
-            clear_flags(arr, NPY.ARRAY_F_CONTIGUOUS)
-            return
-        if dim == 0:
-            break
-        sd *= dim
-    enable_flags(arr, NPY.ARRAY_F_CONTIGUOUS)
+    if is_f_contiguous(arr):
+        enable_flags(arr, NPY.ARRAY_F_CONTIGUOUS)
+    else:
+        clear_flags(arr, NPY.ARRAY_F_CONTIGUOUS)
 
 
 class W_FlagsObject(W_Root):
diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py
--- a/pypy/module/micronumpy/ndarray.py
+++ b/pypy/module/micronumpy/ndarray.py
@@ -19,7 +19,7 @@
     order_converter, shape_converter, searchside_converter
 from pypy.module.micronumpy.flagsobj import W_FlagsObject
 from pypy.module.micronumpy.strides import get_shape_from_iterable, \
-    shape_agreement, shape_agreement_multiple
+    shape_agreement, shape_agreement_multiple, is_c_contiguous, is_f_contiguous
 
 
 def _match_dot_shapes(space, left, right):
@@ -837,7 +837,18 @@
                 raise OperationError(space.w_ValueError, space.wrap(
                     "new type not compatible with array."))
         else:
-            if dims == 1 or impl.get_strides()[0] < impl.get_strides()[-1]:
+            if not is_c_contiguous(impl) and not is_f_contiguous(impl):
+                if old_itemsize != new_itemsize:
+                    raise OperationError(space.w_ValueError, space.wrap(
+                        "new type not compatible with array."))
+                # Strides, shape does not change
+                # NOTE(review): if impl.astype() converts element values (as
+                # numpy's astype does), v neither shares memory with self nor
+                # reinterprets its bytes as dtype; confirm this is intended.
+                v = impl.astype(space, dtype)
+                return wrap_impl(space, w_type, self, v)
+            strides = impl.get_strides()
+            if dims == 1 or strides[0] < strides[-1]:
                 # Column-major, resize first dimension
                 if new_shape[0] * old_itemsize % new_itemsize != 0:
                     raise OperationError(space.w_ValueError, space.wrap(
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -429,3 +429,40 @@
                     n_old_elems_to_use *= old_shape[oldI]
     assert len(new_strides) == len(new_shape)
     return new_strides[:]
+
+
+@jit.unroll_safe
+def _is_contiguous(arr, c_order):
+    """Return True if arr's strides describe a densely packed layout.
+
+    Axes are visited last-to-first when c_order is True (C order) and
+    first-to-last otherwise (Fortran order).  Each visited axis must have
+    a stride equal to the itemsize times the product of the dimensions
+    visited before it; the walk stops after the first zero-length axis.
+    """
+    shape = arr.get_shape()
+    strides = arr.get_strides()
+    ndim = len(shape)
+    sd = arr.dtype.elsize
+    for j in range(ndim):
+        if c_order:
+            i = ndim - 1 - j
+        else:
+            i = j
+        if strides[i] != sd:
+            return False
+        dim = shape[i]
+        if dim == 0:
+            break
+        sd *= dim
+    return True
+
+
+def is_c_contiguous(arr):
+    """Return True if arr is contiguous in C (row-major) order."""
+    return _is_contiguous(arr, True)
+
+
+def is_f_contiguous(arr):
+    """Return True if arr is contiguous in Fortran (column-major) order."""
+    return _is_contiguous(arr, False)
diff --git a/pypy/module/micronumpy/test/test_iterators.py b/pypy/module/micronumpy/test/test_iterators.py
--- a/pypy/module/micronumpy/test/test_iterators.py
+++ b/pypy/module/micronumpy/test/test_iterators.py
@@ -13,6 +13,14 @@
         self.strides = strides
         self.start = start
 
+    def get_shape(self):
+        """Accessor used by strides.is_c_contiguous/is_f_contiguous."""
+        return self.shape
+
+    def get_strides(self):
+        """Accessor used by strides.is_c_contiguous/is_f_contiguous."""
+        return self.strides
+
 
 class TestIterDirect(object):
     def test_iterator_basic(self):
diff --git a/pypy/module/micronumpy/test/test_ndarray.py b/pypy/module/micronumpy/test/test_ndarray.py
--- a/pypy/module/micronumpy/test/test_ndarray.py
+++ b/pypy/module/micronumpy/test/test_ndarray.py
@@ -1807,6 +1807,17 @@
         x = array([], dtype=[('a', 'int8'), ('b', 'int8')])
         y = x.view(dtype='int16')
 
+    def test_view_of_slice(self):
+        from numpy import empty
+        x = empty([6], 'uint32')
+        x.fill(0xdeadbeef)
+        s = x[::3]
+        exc = raises(ValueError, s.view, 'uint8')
+        assert exc.value[0] == 'new type not compatible with array.'
+        s[...] = 2
+        v = s.view(x.__class__)
+        assert (v == 2).all()
+
     def test_tolist_scalar(self):
         from numpy import dtype
         int32 = dtype('int32').type
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to