Author: Ronan Lamy <ronan.l...@gmail.com>
Branch: PyBuffer
Changeset: r90842:912d9dbe8d1f
Date: 2017-03-28 14:53 +0100
http://bitbucket.org/pypy/pypy/changeset/912d9dbe8d1f/

Log:    Simplify _cast_to_ND()

diff --git a/pypy/objspace/std/memoryobject.py b/pypy/objspace/std/memoryobject.py
--- a/pypy/objspace/std/memoryobject.py
+++ b/pypy/objspace/std/memoryobject.py
@@ -546,7 +546,8 @@
                             "memoryview: cast must be 1D -> ND or ND -> 1D")
 
         origfmt = self.getformat()
-        mv = self._cast_to_1D(space, buf, origfmt, fmt)
+        newbuf = self._cast_to_1D(space, buf, origfmt, fmt)
+        mv = W_MemoryView(newbuf, newbuf.getformat(), newbuf.getitemsize())
         if w_shape:
             fview = space.fixedview(w_shape)
             shape = [space.int_w(w_obj) for w_obj in fview]
@@ -604,10 +605,7 @@
         if not newfmt:
             raise oefmt(space.w_RuntimeError,
                     "memoryview: internal error")
-        newbuf = BufferView1D(buf, newfmt, itemsize)
-        mv = W_MemoryView(newbuf, newbuf.getformat(), newbuf.getitemsize())
-        mv._init_flags()
-        return mv
+        return BufferView1D(buf, newfmt, itemsize)
 
     def get_native_fmtstr(self, fmt):
         lenfmt = len(fmt)
@@ -637,33 +635,31 @@
         return None
 
     def _cast_to_ND(self, space, shape, ndim):
-        self.ndim = ndim
-        length = self.itemsize
-        if ndim == 0:
-            self.shape = []
-            self.strides = []
-        else:
-            self.shape = shape
-            for i in range(ndim):
-                length *= shape[i]
-            self._init_strides_from_shape()
-
+        buf = self.buf
+        length = itemsize = buf.getitemsize()
+        for i in range(ndim):
+            length *= shape[i]
         if length != self.buf.getlength():
             raise oefmt(space.w_TypeError,
                         "memoryview: product(shape) * itemsize != buffer size")
 
+        self.ndim = ndim
+        self.shape = shape
+        self.strides = self._strides_from_shape(shape, itemsize)
         self._init_flags()
 
-    def _init_strides_from_shape(self):
-        shape = self.getshape()
-        s = [0] * len(shape)
-        self.strides = s
-        ndim = self.getndim()
-        s[ndim-1] = self.itemsize
-        i = ndim-2
+    @staticmethod
+    def _strides_from_shape(shape, itemsize):
+        ndim = len(shape)
+        if ndim == 0:
+            return []
+        s = [0] * ndim
+        s[ndim - 1] = itemsize
+        i = ndim - 2
         while i >= 0:
             s[i] = s[i+1] * shape[i+1]
             i -= 1
+        return s
 
     def descr_hex(self, space):
         from pypy.objspace.std.bytearrayobject import _array_to_hexstring
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to