Author: Maciej Fijalkowski <[email protected]>
Branch: missing-ndarray-attributes
Changeset: r60815:a719407593e5
Date: 2013-02-02 15:59 +0200
http://bitbucket.org/pypy/pypy/changeset/a719407593e5/

Log:    some fight with RPython

diff --git a/pypy/module/micronumpy/arrayimpl/sort.py b/pypy/module/micronumpy/arrayimpl/sort.py
--- a/pypy/module/micronumpy/arrayimpl/sort.py
+++ b/pypy/module/micronumpy/arrayimpl/sort.py
@@ -7,6 +7,8 @@
 from rpython.rlib.listsort import make_timsort_class
 from rpython.rlib.rawstorage import raw_storage_getitem, raw_storage_setitem, \
         free_raw_storage, alloc_raw_storage
+from rpython.rlib.unroll import unrolling_iterable
+from rpython.rlib.objectmodel import specialize
 from pypy.interpreter.error import OperationError
 from pypy.module.micronumpy.base import W_NDimArray
 from pypy.module.micronumpy import interp_dtype, types
@@ -14,10 +16,10 @@
 
 INT_SIZE = rffi.sizeof(lltype.Signed)
 
-def make_sort_classes(space, itemtype):
+def make_sort_function(space, itemtype):
     TP = itemtype.T
     
-    class ArgArrayRepresentation(object):
+    class Repr(object):
         def __init__(self, index_stride_size, stride_size, size, values,
                      indexes, index_start, start):
             self.index_stride_size = index_stride_size
@@ -41,14 +43,15 @@
                                 self.start, rffi.cast(TP, item[0]))
             raw_storage_setitem(self.indexes, idx * self.index_stride_size +
                                 self.index_start, item[1])
-    class ArgArrayRepWithStorage(ArgArrayRepresentation):
+
+    class ArgArrayRepWithStorage(Repr):
         def __init__(self, index_stride_size, stride_size, size):
             start = 0
             dtype = interp_dtype.get_dtype_cache(space).w_longdtype
             self.indexes = dtype.itemtype.malloc(size*dtype.get_size())
             self.values = alloc_raw_storage(size*rffi.sizeof(TP), track_allocation=False)
-            ArgArrayRepresentation.__init__(self, index_stride_size, stride_size, 
-                    size, self.values, self.indexes, start, start)
+            Repr.__init__(self, index_stride_size, stride_size, 
+                          size, self.values, self.indexes, start, start)
 
         def __del__(self):
             free_raw_storage(self.indexes, track_allocation=False)
@@ -76,62 +79,69 @@
     ArgSort = make_timsort_class(arg_getitem, arg_setitem, arg_length,
                                  arg_getitem_slice, arg_lt)
 
-    return ArgArrayRepresentation, ArgSort
+    def argsort(arr, space, w_axis, itemsize):
+        if w_axis is space.w_None:
+            # note that it's fine ot pass None here as we're not going
+            # to pass the result around (None is the link to base in slices)
+            arr = arr.reshape(space, None, [arr.get_size()])
+            axis = 0
+        elif w_axis is None:
+            axis = -1
+        else:
+            axis = space.int_w(w_axis)
+        # create array of indexes
+        dtype = interp_dtype.get_dtype_cache(space).w_longdtype
+        index_arr = W_NDimArray.from_shape(arr.get_shape(), dtype)
+        storage = index_arr.implementation.get_storage()
+        if len(arr.get_shape()) == 1:
+            for i in range(arr.get_size()):
+                raw_storage_setitem(storage, i * INT_SIZE, i)
+            r = Repr(INT_SIZE, itemsize, arr.get_size(), arr.get_storage(),
+                     storage, 0, arr.start)
+            ArgSort(r).sort()
+        else:
+            shape = arr.get_shape()
+            if axis < 0:
+                axis = len(shape) + axis - 1
+            if axis < 0 or axis > len(shape):
+                raise OperationError(space.w_IndexError, space.wrap(
+                                                    "Wrong axis %d" % axis))
+            iterable_shape = shape[:axis] + [0] + shape[axis + 1:]
+            iter = AxisIterator(arr, iterable_shape, axis, False)
+            index_impl = index_arr.implementation
+            index_iter = AxisIterator(index_impl, iterable_shape, axis, False)
+            stride_size = arr.strides[axis]
+            index_stride_size = index_impl.strides[axis]
+            axis_size = arr.shape[axis]
+            while not iter.done():
+                for i in range(axis_size):
+                    raw_storage_setitem(storage, i * index_stride_size +
+                                        index_iter.offset, i)
+                r = Repr(index_stride_size, stride_size, axis_size,
+                         arr.get_storage(), storage, index_iter.offset, iter.offset)
+                ArgSort(r).sort()
+                iter.next()
+                index_iter.next()
+        return index_arr
+
+    return argsort
 
 def argsort_array(arr, space, w_axis):
-    space.fromcache(SortCache) # that populates SortClasses
+    cache = space.fromcache(SortCache) # that populates SortClasses
     itemtype = arr.dtype.itemtype
-    if itemtype.Sort is None:
-        # XXX this should probably be changed
-        raise OperationError(space.w_NotImplementedError,
-               space.wrap("sorting of non-numeric types " + \
-                      "'%s' is not implemented" % arr.dtype.get_name(), ))
-    if w_axis is space.w_None:
-        # note that it's fine ot pass None here as we're not going
-        # to pass the result around (None is the link to base in slices)
-        arr = arr.reshape(space, None, [arr.get_size()])
-        axis = 0
-    elif w_axis is None:
-        axis = -1
-    else:
-        axis = space.int_w(w_axis)
-    Repr = itemtype.SortRepr
-    Sort = itemtype.Sort
-    itemsize = itemtype.get_element_size()
-    # create array of indexes
-    dtype = interp_dtype.get_dtype_cache(space).w_longdtype
-    index_arr = W_NDimArray.from_shape(arr.get_shape(), dtype)
-    storage = index_arr.implementation.get_storage()
-    if len(arr.get_shape()) == 1:
-        for i in range(arr.get_size()):
-            raw_storage_setitem(storage, i * INT_SIZE, i)
-        r = Repr(INT_SIZE, itemsize, arr.get_size(), arr.get_storage(),
-                 storage, 0, arr.start)
-        Sort(r).sort()
-    else:
-        shape = arr.get_shape()
-        if axis < 0:
-            axis = len(shape) + axis - 1
-        if axis < 0 or axis > len(shape):
-            raise OperationError(space.w_IndexError, space.wrap(
-                                                "Wrong axis %d" % axis))
-        iterable_shape = shape[:axis] + [0] + shape[axis + 1:]
-        iter = AxisIterator(arr, iterable_shape, axis, False)
-        index_impl = index_arr.implementation
-        index_iter = AxisIterator(index_impl, iterable_shape, axis, False)
-        stride_size = arr.strides[axis]
-        index_stride_size = index_impl.strides[axis]
-        axis_size = arr.shape[axis]
-        while not iter.done():
-            for i in range(axis_size):
-                raw_storage_setitem(storage, i * index_stride_size +
-                                    index_iter.offset, i)
-            r = Repr(index_stride_size, stride_size, axis_size,
-                     arr.get_storage(), storage, index_iter.offset, iter.offset)
-            Sort(r).sort()
-            iter.next()
-            index_iter.next()
-    return index_arr
+    for tp in all_types:
+        if isinstance(itemtype, tp):
+            return cache._lookup(tp)(arr, space, w_axis,
+                                     itemtype.get_element_size())
+    # XXX this should probably be changed
+    raise OperationError(space.w_NotImplementedError,
+           space.wrap("sorting of non-numeric types " + \
+                  "'%s' is not implemented" % arr.dtype.get_name(), ))
+
+all_types = (types.all_int_types + types.all_complex_types +
+             types.all_float_types)
+all_types = [i for i in all_types if not '_mixin_' in i.__dict__]
+all_types = unrolling_iterable(all_types)
 
 class SortCache(object):
     built = False
@@ -139,9 +149,9 @@
     def __init__(self, space):
         if self.built:
             return
-        for cls in types.all_float_types:
-            cls.SortRepr, cls.Sort = make_sort_classes(space, cls)
-        for cls in types.all_int_types:
-            cls.SortRepr, cls.Sort = make_sort_classes(space, cls)
-        for cls in types.all_complex_types:
-            cls.SortRepr, cls.Sort = make_sort_classes(space, cls)
+        self.built = True
+        cache = {}
+        for cls in all_types._items:
+            cache[cls] = make_sort_function(space, cls)
+        self.cache = cache
+        self._lookup = specialize.memo()(lambda tp : cache[tp])
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -476,24 +476,6 @@
                                 'jump': 1,
                                 'raw_store': 1})
     
-    def define_count_nonzero():
-        return """
-        a = [[0, 2, 3, 4], [5, 6, 0, 8], [9, 10, 11, 0]]
-        count_nonzero(a) 
-        """
-
-    def test_count_nonzero(self):
-        result = self.run("count_nonzero")
-        assert result == 9
-        self.check_simple_loop({'setfield_gc': 3, 
-                                'raw_load': 1, 
-                                'guard_false': 1, 
-                                'jump': 1, 
-                                'int_ge': 1, 
-                                'new_with_vtable': 1, 
-                                'int_add': 2, 
-                                'float_ne': 1})
-
     def define_argsort():
         return """
         a = |30|
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to