Author: mattip <[email protected]>
Branch:
Changeset: r74891:b75a06bcc48d
Date: 2014-12-11 18:36 +0200
http://bitbucket.org/pypy/pypy/changeset/b75a06bcc48d/
Log: test and fix ndarray.view() of a non-contiguous slice
diff --git a/pypy/module/micronumpy/flagsobj.py b/pypy/module/micronumpy/flagsobj.py
--- a/pypy/module/micronumpy/flagsobj.py
+++ b/pypy/module/micronumpy/flagsobj.py
@@ -5,46 +5,26 @@
from pypy.interpreter.gateway import interp2app
from pypy.interpreter.typedef import TypeDef, GetSetProperty
from pypy.module.micronumpy import constants as NPY
-
+from pypy.module.micronumpy.strides import is_c_contiguous, is_f_contiguous
def enable_flags(arr, flags):
arr.flags |= flags
-
def clear_flags(arr, flags):
arr.flags &= ~flags
-
-@jit.unroll_safe
def _update_contiguous_flags(arr):
- shape = arr.shape
- strides = arr.strides
-
- is_c_contig = True
- sd = arr.dtype.elsize
- for i in range(len(shape) - 1, -1, -1):
- dim = shape[i]
- if strides[i] != sd:
- is_c_contig = False
- break
- if dim == 0:
- break
- sd *= dim
+ is_c_contig = is_c_contiguous(arr)  # stride walk (and its unroll hint) now lives in strides.py
if is_c_contig:
enable_flags(arr, NPY.ARRAY_C_CONTIGUOUS)
else:
clear_flags(arr, NPY.ARRAY_C_CONTIGUOUS)
- sd = arr.dtype.elsize
- for i in range(len(shape)):
- dim = shape[i]
- if strides[i] != sd:
- clear_flags(arr, NPY.ARRAY_F_CONTIGUOUS)
- return
- if dim == 0:
- break
- sd *= dim
- enable_flags(arr, NPY.ARRAY_F_CONTIGUOUS)
+ is_f_contig = is_f_contiguous(arr)  # F flag is recomputed independently of the C result
+ if is_f_contig:
+ enable_flags(arr, NPY.ARRAY_F_CONTIGUOUS)
+ else:
+ clear_flags(arr, NPY.ARRAY_F_CONTIGUOUS)
class W_FlagsObject(W_Root):
diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py
--- a/pypy/module/micronumpy/ndarray.py
+++ b/pypy/module/micronumpy/ndarray.py
@@ -19,7 +19,7 @@
order_converter, shape_converter, searchside_converter
from pypy.module.micronumpy.flagsobj import W_FlagsObject
from pypy.module.micronumpy.strides import get_shape_from_iterable, \
- shape_agreement, shape_agreement_multiple
+ shape_agreement, shape_agreement_multiple, is_c_contiguous, is_f_contiguous
def _match_dot_shapes(space, left, right):
@@ -837,7 +837,15 @@
raise OperationError(space.w_ValueError, space.wrap(
"new type not compatible with array."))
else:
- if dims == 1 or impl.get_strides()[0] < impl.get_strides()[-1]:
+ if not is_c_contiguous(impl) and not is_f_contiguous(impl):
+ if old_itemsize != new_itemsize:
+ raise OperationError(space.w_ValueError, space.wrap(
+ "new type not compatible with array."))
+ # Neither C- nor F-contiguous: only a same-itemsize view is allowed, so shape and strides carry over unchanged
+ v = impl.astype(space, dtype)  # NOTE(review): confirm astype shares impl's storage (a true view) rather than returning a converted copy
+ return wrap_impl(space, w_type, self, v)
+ strides = impl.get_strides()
+ if dims == 1 or strides[0] < strides[-1]:
# Column-major, resize first dimension
if new_shape[0] * old_itemsize % new_itemsize != 0:
raise OperationError(space.w_ValueError, space.wrap(
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -429,3 +429,35 @@
n_old_elems_to_use *= old_shape[oldI]
assert len(new_strides) == len(new_shape)
return new_strides[:]
+
+@jit.unroll_safe
+def is_c_contiguous(arr):
+    # Row-major (C order) test: walking the axes from last to first,
+    # each stride must equal elsize times the product of the
+    # dimensions already visited; stop at the first mismatch.
+    shape = arr.get_shape()
+    strides = arr.get_strides()
+    expected = arr.dtype.elsize
+    for i in range(len(shape) - 1, -1, -1):
+        if strides[i] != expected:
+            return False
+        if shape[i] == 0:
+            break  # zero-length axis: the remaining axes are not checked
+        expected *= shape[i]
+    return True
+
+@jit.unroll_safe
+def is_f_contiguous(arr):
+    # Column-major (Fortran order) test: walking the axes from first to
+    # last, each stride must equal elsize times the product of the
+    # dimensions already visited; stop at the first mismatch.
+    shape = arr.get_shape()
+    strides = arr.get_strides()
+    expected = arr.dtype.elsize
+    for i in range(len(shape)):
+        if strides[i] != expected:
+            return False
+        if shape[i] == 0:
+            break  # zero-length axis: the remaining axes are not checked
+        expected *= shape[i]
+    return True
diff --git a/pypy/module/micronumpy/test/test_iterators.py b/pypy/module/micronumpy/test/test_iterators.py
--- a/pypy/module/micronumpy/test/test_iterators.py
+++ b/pypy/module/micronumpy/test/test_iterators.py
@@ -13,6 +13,11 @@
self.strides = strides
self.start = start
+ def get_shape(self):  # same accessor name that strides.is_c_contiguous/is_f_contiguous call on arr
+ return self.shape
+
+ def get_strides(self):  # same accessor name that strides.is_c_contiguous/is_f_contiguous call on arr
+ return self.strides
class TestIterDirect(object):
def test_iterator_basic(self):
diff --git a/pypy/module/micronumpy/test/test_ndarray.py b/pypy/module/micronumpy/test/test_ndarray.py
--- a/pypy/module/micronumpy/test/test_ndarray.py
+++ b/pypy/module/micronumpy/test/test_ndarray.py
@@ -1807,6 +1807,17 @@
x = array([], dtype=[('a', 'int8'), ('b', 'int8')])
y = x.view(dtype='int16')
+ def test_view_of_slice(self):  # regression test: view() on a strided slice
+ from numpy import empty
+ x = empty([6], 'uint32')
+ x.fill(0xdeadbeef)
+ s = x[::3]  # every 3rd uint32: stride 12 bytes vs itemsize 4, so neither C- nor F-contiguous
+ exc = raises(ValueError, s.view, 'uint8')  # itemsize change (4 -> 1) on a non-contiguous array must be rejected
+ assert exc.value[0] == 'new type not compatible with array.'
+ s[...] = 2
+ v = s.view(x.__class__)  # NOTE(review): passes a class, not a dtype; a same-itemsize dtype view (e.g. s.view('int32')) may be needed to cover the new astype branch
+ assert (v == 2).all()
+
def test_tolist_scalar(self):
from numpy import dtype
int32 = dtype('int32').type
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit