Author: Maciej Fijalkowski <fij...@gmail.com>
Branch: numpy-faster-setslice
Changeset: r50808:b0190f46f44c
Date: 2011-12-21 20:55 +0200
http://bitbucket.org/pypy/pypy/changeset/b0190f46f44c/

Log:    Implement fast slice setting using memcpy

diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -102,3 +102,27 @@
 class ConstantIterator(BaseIterator):
     def next(self, shapelen):
         return self
+
+# ------ other iterators that are not part of the computation frame ----------
+
+class AxisIterator(object):
+    """Step through the start of every run along arr's last axis, in C order.
+    self.offset is the current run's start; self.done is set once next() passes the last run."""
+    def __init__(self, arr):
+        self.arr = arr
+        self.indices = [0] * (len(arr.shape) - 1)  # one index per leading (non-last) axis
+        self.done = False  # NOTE(review): starts False even if a leading axis has length 0
+        self.offset = arr.start  # start offset of the first run
+
+    def next(self):
+        for i in range(len(self.arr.shape) - 2, -1, -1):  # innermost leading axis first
+            if self.indices[i] < self.arr.shape[i] - 1:
+                self.indices[i] += 1
+                self.offset += self.arr.strides[i]  # one step along axis i
+                break
+            else:  # axis i exhausted: reset it and carry into axis i - 1
+                self.indices[i] = 0
+                self.offset -= self.arr.backstrides[i]  # presumably strides[i] * (shape[i] - 1); confirm
+        else:  # no leading axis could advance: every run has been visited
+            self.done = True
+        
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -9,7 +9,7 @@
 from pypy.tool.sourcetools import func_with_new_name
 from pypy.rlib.rstring import StringBuilder
 from pypy.module.micronumpy.interp_iter import ArrayIterator,\
-     view_iter_from_arr, OneDimIterator
+     view_iter_from_arr, OneDimIterator, AxisIterator
 
 numpy_driver = jit.JitDriver(
     greens=['shapelen', 'sig'],
@@ -606,6 +606,9 @@
                                                        space.w_False]))
         return w_d
 
+    def supports_fast_slicing(self):  # True allows setslice() to copy with memcpy
+        return False  # not memcpy-able here; setslice() uses _sliceloop() instead
+
 def convert_to_array(space, w_obj):
     if isinstance(w_obj, BaseArray):
         return w_obj
@@ -790,6 +793,9 @@
     def get_concrete(self):
         return self
 
+    def supports_fast_slicing(self):  # C order + unit stride on the last axis (0-d excluded)
+        return self.order == 'C' and len(self.strides) > 0 and self.strides[-1] == 1
+
     def find_dtype(self):
         return self.dtype
 
@@ -961,7 +967,54 @@

     def setslice(self, space, w_value):
         res_shape = shape_agreement(space, self.shape, w_value.shape)
-        self._sliceloop(w_value, res_shape)
+        # memcpy fast path: w_value is not broadcast (it already has the
+        # result shape), both sides support fast slicing and the dtypes are
+        # identical, so raw bytes can be copied with no conversion.
+        # Anything else goes through the generic element loop.
+        if (res_shape == w_value.shape and self.supports_fast_slicing() and
+            w_value.supports_fast_slicing() and
+            self.dtype is w_value.find_dtype()):
+            self._fast_setslice(space, w_value)
+        else:
+            self._sliceloop(w_value, res_shape)
+
+    def _fast_setslice(self, space, w_value):
+        """Copy w_value into self with raw memcpy calls.
+
+        Only valid after setslice() has checked that w_value is not
+        broadcast, the dtypes are identical and both arrays support fast
+        slicing, so each run along the last axis is shape[-1] adjacent
+        items on both sides.  NOTE(review): memcpy is undefined for
+        overlapping memory; assigning from an overlapping view of the same
+        storage (e.g. a[1:] = a[:-1]) is not special-cased -- memmove?
+        """
+        assert isinstance(w_value, ConcreteArray)
+        if self.size == 0:
+            # Nothing to copy.  Returning early also matters for the loop
+            # below: AxisIterator starts with done == False, so it would
+            # issue one memcpy even when some leading axis has length 0.
+            return
+        itemsize = self.dtype.itemtype.get_element_size()
+        if len(self.shape) == 1:
+            rffi.c_memcpy(
+                rffi.ptradd(self.storage, self.start * itemsize),
+                rffi.ptradd(w_value.storage, w_value.start * itemsize),
+                self.size * itemsize
+            )
+        else:
+            # Each AxisIterator offset is the start of one run along the
+            # last axis; only that axis is known to be contiguous, so each
+            # copy is self.shape[-1] items long (not self.shape[0]).
+            dest = AxisIterator(self)
+            source = AxisIterator(w_value)
+            while not dest.done:
+                rffi.c_memcpy(
+                    rffi.ptradd(self.storage, dest.offset * itemsize),
+                    rffi.ptradd(w_value.storage, source.offset * itemsize),
+                    self.shape[-1] * itemsize
+                )
+                source.next()
+                dest.next()

     def _sliceloop(self, source, res_shape):
         sig = source.find_sig(res_shape)
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1077,6 +1077,23 @@
         a = ones((1, 2, 3))
         assert a[0, 1, 2] == 1.0
 
+    def test_multidim_setslice(self):
+        from numpypy import zeros, ones
+        a = zeros((3, 3))
+        b = ones((3, 3))
+        # (3, 2) views: each contiguous row run is shorter than shape[0]
+        a[:,1:3] = b[:,1:3]
+        assert (a == [[0, 1, 1], [0, 1, 1], [0, 1, 1]]).all()
+        a = zeros((3, 3))
+        b = ones((3, 3))
+        a[:,::2] = b[:,::2]
+        assert (a == [[1, 0, 1], [1, 0, 1], [1, 0, 1]]).all()
+        # 3-D: walking the leading axes must carry from axis 1 into axis 0
+        a = zeros((2, 2, 3))
+        b = ones((2, 2, 3))
+        a[:,:,1:3] = b[:,:,1:3]
+        assert (a == [[[0, 1, 1], [0, 1, 1]], [[0, 1, 1], [0, 1, 1]]]).all()
+
     def test_broadcast_ufunc(self):
         from numpypy import array
         a = array([[1, 2], [3, 4], [5, 6]])
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to