Author: Maciej Fijalkowski <[email protected]>
Branch: missing-ndarray-attributes
Changeset: r58484:4b874f254549
Date: 2012-10-27 00:06 +0200
http://bitbucket.org/pypy/pypy/changeset/4b874f254549/

Log:    shuffle stuff around and pass argsort tests

diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -1,6 +1,6 @@
 
 from pypy.module.micronumpy.arrayimpl import base
-from pypy.module.micronumpy import support, loop
+from pypy.module.micronumpy import support, loop, iter
 from pypy.module.micronumpy.base import convert_to_array, W_NDimArray,\
      ArrayArgumentException
 from pypy.module.micronumpy.strides import calc_new_strides, shape_agreement,\
@@ -12,152 +12,6 @@
 from pypy.rlib.rawstorage import free_raw_storage
 from pypy.module.micronumpy.arrayimpl.sort import argsort_array
 
-class ConcreteArrayIterator(base.BaseArrayIterator):
-    def __init__(self, array):
-        self.array = array
-        self.offset = 0
-        self.dtype = array.dtype
-        self.skip = self.dtype.itemtype.get_element_size()
-        self.size = array.size
-
-    def setitem(self, elem):
-        self.array.setitem(self.offset, elem)
-
-    def getitem(self):
-        return self.array.getitem(self.offset)
-
-    def getitem_bool(self):
-        return self.dtype.getitem_bool(self.array, self.offset)
-
-    def next(self):
-        self.offset += self.skip
-
-    def next_skip_x(self, x):
-        self.offset += self.skip * x
-
-    def done(self):
-        return self.offset >= self.size
-
-    def reset(self):
-        self.offset %= self.size
-
-class OneDimViewIterator(ConcreteArrayIterator):
-    def __init__(self, array):
-        self.array = array
-        self.offset = array.start
-        self.skip = array.strides[0]
-        self.dtype = array.dtype
-        self.index = 0
-        self.size = array.shape[0]
-
-    def next(self):
-        self.offset += self.skip
-        self.index += 1
-
-    def next_skip_x(self, x):
-        self.offset += self.skip * x
-        self.index += x
-
-    def done(self):
-        return self.index >= self.size
-
-    def reset(self):
-        self.offset %= self.size
-
-class MultiDimViewIterator(ConcreteArrayIterator):
-    def __init__(self, array, start, strides, backstrides, shape):
-        self.indexes = [0] * len(shape)
-        self.array = array
-        self.shape = shape
-        self.offset = start
-        self.shapelen = len(shape)
-        self._done = False
-        self.strides = strides
-        self.backstrides = backstrides
-        self.size = array.size
-
-    @jit.unroll_safe
-    def next(self):
-        offset = self.offset
-        for i in range(self.shapelen - 1, -1, -1):
-            if self.indexes[i] < self.shape[i] - 1:
-                self.indexes[i] += 1
-                offset += self.strides[i]
-                break
-            else:
-                self.indexes[i] = 0
-                offset -= self.backstrides[i]
-        else:
-            self._done = True
-        self.offset = offset
-
-    @jit.unroll_safe
-    def next_skip_x(self, step):
-        for i in range(len(self.shape) - 1, -1, -1):
-            if self.indexes[i] < self.shape[i] - step:
-                self.indexes[i] += step
-                self.offset += self.strides[i] * step
-                break
-            else:
-                remaining_step = (self.indexes[i] + step) // self.shape[i]
-                this_i_step = step - remaining_step * self.shape[i]
-                self.offset += self.strides[i] * this_i_step
-                self.indexes[i] = self.indexes[i] +  this_i_step
-                step = remaining_step
-        else:
-            self._done = True
-
-    def done(self):
-        return self._done
-
-    def reset(self):
-        self.offset %= self.size
-
-class AxisIterator(base.BaseArrayIterator):
-    def __init__(self, array, shape, dim):
-        self.shape = shape
-        strides = array.strides
-        backstrides = array.backstrides
-        if len(shape) == len(strides):
-            # keepdims = True
-            self.strides = strides[:dim] + [0] + strides[dim + 1:]
-            self.backstrides = backstrides[:dim] + [0] + backstrides[dim + 1:]
-        else:
-            self.strides = strides[:dim] + [0] + strides[dim:]
-            self.backstrides = backstrides[:dim] + [0] + backstrides[dim:]
-        self.first_line = True
-        self.indices = [0] * len(shape)
-        self._done = False
-        self.offset = array.start
-        self.dim = dim
-        self.array = array
-        
-    def setitem(self, elem):
-        self.array.setitem(self.offset, elem)
-
-    def getitem(self):
-        return self.array.getitem(self.offset)
-
-    @jit.unroll_safe
-    def next(self):
-        for i in range(len(self.shape) - 1, -1, -1):
-            if self.indices[i] < self.shape[i] - 1:
-                if i == self.dim:
-                    self.first_line = False
-                self.indices[i] += 1
-                self.offset += self.strides[i]
-                break
-            else:
-                if i == self.dim:
-                    self.first_line = True
-                self.indices[i] = 0
-                self.offset -= self.backstrides[i]
-        else:
-            self._done = True
-
-    def done(self):
-        return self._done
-
 def int_w(space, w_obj):
     try:
         return space.int_w(space.index(w_obj))
@@ -354,12 +208,12 @@
         return loop.setslice(self.shape, impl, self)
 
     def create_axis_iter(self, shape, dim):
-        return AxisIterator(self, shape, dim)
+        return iter.AxisIterator(self, shape, dim)
 
     def create_dot_iter(self, shape, skip):
         r = calculate_dot_strides(self.strides, self.backstrides,
                                   shape, skip)
-        return MultiDimViewIterator(self, self.start, r[0], r[1], shape)
+        return iter.MultiDimViewIterator(self, self.start, r[0], r[1], shape)
 
     def swapaxes(self, axis1, axis2):
         shape = self.shape[:]
@@ -389,10 +243,10 @@
 
     def create_iter(self, shape):
         if shape == self.shape:
-            return ConcreteArrayIterator(self)
+            return iter.ConcreteArrayIterator(self)
         r = calculate_broadcast_strides(self.strides, self.backstrides,
                                         self.shape, shape)
-        return MultiDimViewIterator(self, 0, r[0], r[1], shape)
+        return iter.MultiDimViewIterator(self, 0, r[0], r[1], shape)
 
     def fill(self, box):
         self.dtype.fill(self.storage, box, 0, self.size)
@@ -431,12 +285,12 @@
         if shape != self.shape:
             r = calculate_broadcast_strides(self.strides, self.backstrides,
                                             self.shape, shape)
-            return MultiDimViewIterator(self.parent,
+            return iter.MultiDimViewIterator(self.parent,
                                         self.start, r[0], r[1], shape)
         if len(self.shape) == 1:
-            return OneDimViewIterator(self)
-        return MultiDimViewIterator(self.parent, self.start, self.strides,
-                                    self.backstrides, self.shape)
+            return iter.OneDimViewIterator(self)
+        return iter.MultiDimViewIterator(self.parent, self.start, self.strides,
+                                         self.backstrides, self.shape)
 
     def set_shape(self, space, new_shape):
         if len(self.shape) < 2 or self.size == 0:
diff --git a/pypy/module/micronumpy/arrayimpl/sort.py b/pypy/module/micronumpy/arrayimpl/sort.py
--- a/pypy/module/micronumpy/arrayimpl/sort.py
+++ b/pypy/module/micronumpy/arrayimpl/sort.py
@@ -11,6 +11,7 @@
 from pypy.interpreter.error import OperationError
 from pypy.module.micronumpy.base import W_NDimArray
 from pypy.module.micronumpy import interp_dtype, types
+from pypy.module.micronumpy.iter import AxisIterator
 
 INT_SIZE = rffi.sizeof(lltype.Signed)
 
@@ -19,22 +20,30 @@
     TP = itemtype.T
    
     class ArgArrayRepresentation(object):
-        def __init__(self, itemsize, size, values, indexes):
-            self.itemsize = itemsize
+        def __init__(self, index_stride_size, stride_size, size, values,
+                     indexes, index_start, start):
+            self.index_stride_size = index_stride_size
+            self.stride_size = stride_size
+            self.index_start = index_start
+            self.start = start
             self.size = size
             self.values = values
             self.indexes = indexes
+            # strides and start offsets are byte offsets into raw storage
 
         def getitem(self, item):
-            v = raw_storage_getitem(TP, self.values, item * self.itemsize)
+            v = raw_storage_getitem(TP, self.values, item * self.stride_size
+                                    + self.start)
             v = itemtype.for_computation(v)
             return (v, raw_storage_getitem(lltype.Signed, self.indexes,
-                                           item * INT_SIZE))
+                                           item * self.index_stride_size +
+                                           self.index_start))
 
         def setitem(self, idx, item):
-            raw_storage_setitem(self.values, idx * self.itemsize,
-                                rffi.cast(TP, item[0]))
-            raw_storage_setitem(self.indexes, idx * INT_SIZE, item[1])
+            raw_storage_setitem(self.values, idx * self.stride_size +
+                                self.start, rffi.cast(TP, item[0]))
+            raw_storage_setitem(self.indexes, idx * self.index_stride_size +
+                                self.index_start, item[1])
 
     def arg_getitem(lst, item):
         return lst.getitem(item)
@@ -73,15 +82,36 @@
     itemsize = itemtype.get_element_size()
     # create array of indexes
     dtype = interp_dtype.get_dtype_cache(space).w_longdtype
-    indexes = W_NDimArray.from_shape(arr.get_shape(), dtype)
+    index_arr = W_NDimArray.from_shape(arr.get_shape(), dtype)
+    storage = index_arr.implementation.get_storage()
     if len(arr.get_shape()) == 1:
-        storage = indexes.implementation.get_storage()
         for i in range(arr.get_size()):
             raw_storage_setitem(storage, i * INT_SIZE, i)
         Repr, Sort = make_sort_classes(space, itemtype)
-        r = Repr(itemsize, arr.get_size(), arr.get_storage(),
-                 indexes.implementation.get_storage())
+        r = Repr(INT_SIZE, arr.strides[0], arr.get_size(), arr.get_storage(),
+                 storage, 0, arr.start)
         Sort(r).sort()
     else:
-        xxx
-    return indexes
+        shape = arr.get_shape()
+        if axis < 0:
+            axis = len(shape) + axis  # numpy semantics: -1 is the last axis
+        if axis < 0 or axis >= len(shape):
+            raise OperationError(space.w_IndexError, space.wrap("Wrong axis %d" % axis))
+        iterable_shape = shape[:axis] + [0] + shape[axis + 1:]
+        arr_iter = AxisIterator(arr, iterable_shape, axis)
+        index_impl = index_arr.implementation
+        index_iter = AxisIterator(index_impl, iterable_shape, axis)
+        stride_size = arr.strides[axis]
+        index_stride_size = index_impl.strides[axis]
+        axis_size = arr.shape[axis]
+        Repr, Sort = make_sort_classes(space, itemtype)
+        while not arr_iter.done():  # one argsort per 1-d slice along axis
+            for i in range(axis_size):
+                raw_storage_setitem(storage, i * index_stride_size +
+                                    index_iter.offset, i)
+            r = Repr(index_stride_size, stride_size, axis_size,
+                     arr.get_storage(), storage, index_iter.offset, arr_iter.offset)
+            Sort(r).sort()
+            arr_iter.next()
+            index_iter.next()
+    return index_arr
diff --git a/pypy/module/micronumpy/iter.py b/pypy/module/micronumpy/iter.py
--- a/pypy/module/micronumpy/iter.py
+++ b/pypy/module/micronumpy/iter.py
@@ -45,6 +45,7 @@
 from pypy.module.micronumpy.strides import enumerate_chunks,\
      calculate_slice_strides
 from pypy.module.micronumpy.base import W_NDimArray
+from pypy.module.micronumpy.arrayimpl import base
 from pypy.rlib import jit
 
 # structures to describe slicing
@@ -121,3 +122,181 @@
 class BroadcastTransform(BaseTransform):
     def __init__(self, res_shape):
         self.res_shape = res_shape
+
+class PureShapeIterator(object):
+    def __init__(self, shape, idx_w):
+        self.shape = shape
+        self.shapelen = len(shape)
+        self.indexes = [0] * len(shape)
+        self._done = False
+        self.idx_w = [None] * len(idx_w)
+        for i, w_idx in enumerate(idx_w):
+            if isinstance(w_idx, W_NDimArray):
+                self.idx_w[i] = w_idx.create_iter(shape)
+
+    def done(self):
+        return self._done
+
+    @jit.unroll_safe
+    def next(self):
+        for w_idx in self.idx_w:
+            if w_idx is not None:
+                w_idx.next()
+        for i in range(self.shapelen - 1, -1, -1):
+            if self.indexes[i] < self.shape[i] - 1:
+                self.indexes[i] += 1
+                break
+            else:
+                self.indexes[i] = 0
+        else:
+            self._done = True
+
+    @jit.unroll_safe
+    def get_index(self, space, shapelen):
+        return [space.wrap(self.indexes[i]) for i in range(shapelen)]
+
+class ConcreteArrayIterator(base.BaseArrayIterator):
+    def __init__(self, array):
+        self.array = array
+        self.offset = 0
+        self.dtype = array.dtype
+        self.skip = self.dtype.itemtype.get_element_size()
+        self.size = array.size
+
+    def setitem(self, elem):
+        self.array.setitem(self.offset, elem)
+
+    def getitem(self):
+        return self.array.getitem(self.offset)
+
+    def getitem_bool(self):
+        return self.dtype.getitem_bool(self.array, self.offset)
+
+    def next(self):
+        self.offset += self.skip
+
+    def next_skip_x(self, x):
+        self.offset += self.skip * x
+
+    def done(self):
+        return self.offset >= self.size
+
+    def reset(self):
+        self.offset %= self.size
+
+class OneDimViewIterator(ConcreteArrayIterator):
+    def __init__(self, array):
+        self.array = array
+        self.offset = array.start
+        self.skip = array.strides[0]
+        self.dtype = array.dtype
+        self.index = 0
+        self.size = array.shape[0]
+
+    def next(self):
+        self.offset += self.skip
+        self.index += 1
+
+    def next_skip_x(self, x):
+        self.offset += self.skip * x
+        self.index += x
+
+    def done(self):
+        return self.index >= self.size
+
+    def reset(self):
+        self.offset %= self.size
+
+class MultiDimViewIterator(ConcreteArrayIterator):
+    def __init__(self, array, start, strides, backstrides, shape):
+        self.indexes = [0] * len(shape)
+        self.array = array
+        self.shape = shape
+        self.offset = start
+        self.shapelen = len(shape)
+        self._done = False
+        self.strides = strides
+        self.backstrides = backstrides
+        self.size = array.size
+
+    @jit.unroll_safe
+    def next(self):
+        offset = self.offset
+        for i in range(self.shapelen - 1, -1, -1):
+            if self.indexes[i] < self.shape[i] - 1:
+                self.indexes[i] += 1
+                offset += self.strides[i]
+                break
+            else:
+                self.indexes[i] = 0
+                offset -= self.backstrides[i]
+        else:
+            self._done = True
+        self.offset = offset
+
+    @jit.unroll_safe
+    def next_skip_x(self, step):
+        for i in range(len(self.shape) - 1, -1, -1):
+            if self.indexes[i] < self.shape[i] - step:
+                self.indexes[i] += step
+                self.offset += self.strides[i] * step
+                break
+            else:
+                remaining_step = (self.indexes[i] + step) // self.shape[i]
+                this_i_step = step - remaining_step * self.shape[i]
+                self.offset += self.strides[i] * this_i_step
+                self.indexes[i] = self.indexes[i] +  this_i_step
+                step = remaining_step
+        else:
+            self._done = True
+
+    def done(self):
+        return self._done
+
+    def reset(self):
+        self.offset %= self.size
+
+class AxisIterator(base.BaseArrayIterator):
+    def __init__(self, array, shape, dim):
+        self.shape = shape
+        strides = array.strides
+        backstrides = array.backstrides
+        if len(shape) == len(strides):
+            # keepdims = True
+            self.strides = strides[:dim] + [0] + strides[dim + 1:]
+            self.backstrides = backstrides[:dim] + [0] + backstrides[dim + 1:]
+        else:
+            self.strides = strides[:dim] + [0] + strides[dim:]
+            self.backstrides = backstrides[:dim] + [0] + backstrides[dim:]
+        self.first_line = True
+        self.indices = [0] * len(shape)
+        self._done = False
+        self.offset = array.start
+        self.dim = dim
+        self.array = array
+        
+    def setitem(self, elem):
+        self.array.setitem(self.offset, elem)
+
+    def getitem(self):
+        return self.array.getitem(self.offset)
+
+    @jit.unroll_safe
+    def next(self):
+        for i in range(len(self.shape) - 1, -1, -1):
+            if self.indices[i] < self.shape[i] - 1:
+                if i == self.dim:
+                    self.first_line = False
+                self.indices[i] += 1
+                self.offset += self.strides[i]
+                break
+            else:
+                if i == self.dim:
+                    self.first_line = True
+                self.indices[i] = 0
+                self.offset -= self.backstrides[i]
+        else:
+            self._done = True
+
+    def done(self):
+        return self._done
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -8,6 +8,7 @@
 from pypy.rlib import jit
 from pypy.rpython.lltypesystem import lltype, rffi
 from pypy.module.micronumpy.base import W_NDimArray
+from pypy.module.micronumpy.iter import PureShapeIterator
 
 call2_driver = jit.JitDriver(name='numpy_call2',
                              greens = ['shapelen', 'func', 'calc_dtype',
@@ -395,38 +396,6 @@
         iter.next()
     return builder.build()
 
-class PureShapeIterator(object):
-    def __init__(self, shape, idx_w):
-        self.shape = shape
-        self.shapelen = len(shape)
-        self.indexes = [0] * len(shape)
-        self._done = False
-        self.idx_w = [None] * len(idx_w)
-        for i, w_idx in enumerate(idx_w):
-            if isinstance(w_idx, W_NDimArray):
-                self.idx_w[i] = w_idx.create_iter(shape)
-
-    def done(self):
-        return self._done
-
-    @jit.unroll_safe
-    def next(self):
-        for w_idx in self.idx_w:
-            if w_idx is not None:
-                w_idx.next()
-        for i in range(self.shapelen - 1, -1, -1):
-            if self.indexes[i] < self.shape[i] - 1:
-                self.indexes[i] += 1
-                break
-            else:
-                self.indexes[i] = 0
-        else:
-            self._done = True
-
-    @jit.unroll_safe
-    def get_index(self, space, shapelen):
-        return [space.wrap(self.indexes[i]) for i in range(shapelen)]
-
 getitem_int_driver = jit.JitDriver(name = 'numpy_getitem_int',
                                    greens = ['shapelen', 'indexlen',
                                              'prefixlen', 'dtype'],
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to