Author: Matti Picus <[email protected]>
Branch: ndarray-sort
Changeset: r67263:83de01401b30
Date: 2013-10-09 20:24 +0300
http://bitbucket.org/pypy/pypy/changeset/83de01401b30/
Log: some tests pass; raise NotImplementedError for non-native byte order
diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -356,6 +356,10 @@
from pypy.module.micronumpy.arrayimpl.sort import argsort_array
return argsort_array(self, space, w_axis)
+    def sort(self, space, w_axis, w_order):
+        """Delegate an in-place sort of this array to sort.sort_array."""
+        from pypy.module.micronumpy.arrayimpl import sort as sort_module
+        return sort_module.sort_array(self, space, w_axis, w_order)
+
def base(self):
return None
diff --git a/pypy/module/micronumpy/arrayimpl/sort.py b/pypy/module/micronumpy/arrayimpl/sort.py
--- a/pypy/module/micronumpy/arrayimpl/sort.py
+++ b/pypy/module/micronumpy/arrayimpl/sort.py
@@ -17,7 +17,7 @@
INT_SIZE = rffi.sizeof(lltype.Signed)
-def make_sort_function(space, itemtype, comp_type, count=1):
+def make_argsort_function(space, itemtype, comp_type, count=1):
TP = itemtype.T
step = rffi.sizeof(TP)
@@ -162,7 +162,7 @@
return argsort
def argsort_array(arr, space, w_axis):
- cache = space.fromcache(SortCache) # that populates SortClasses
+ cache = space.fromcache(ArgSortCache) # that populates ArgSortClasses
itemtype = arr.dtype.itemtype
for tp in all_types:
if isinstance(itemtype, tp[0]):
@@ -178,6 +178,161 @@
all_types = [i for i in all_types if not '_mixin_' in i[0].__dict__]
all_types = unrolling_iterable(all_types)
+def make_sort_function(space, itemtype, comp_type, count=1):
+    """Build an in-place sort function specialised for `itemtype`.
+
+    `comp_type` ('int', 'float' or 'complex') selects how raw storage
+    values are converted before being compared.  `count` is how many raw
+    `itemtype.T` values make up one array element (2 for complex types,
+    whose elements are read as [real, imag]).  Returns
+    `sort(arr, space, w_axis, itemsize)`, which sorts the storage of `arr`
+    in place using a timsort class built from the helpers below.
+    """
+    TP = itemtype.T
+    step = rffi.sizeof(TP)
+    # Bytes occupied by one element in a packed (contiguous) buffer.
+    element_size = step * count
+
+    class Repr(object):
+        # Strided view of `size` elements in raw storage `values`: element
+        # k lives at byte offset `start + k * stride_size`.
+        def __init__(self, stride_size, size, values, start):
+            self.stride_size = stride_size
+            self.start = start
+            self.size = size
+            self.values = values
+
+        def getitem(self, item):
+            # Read element `item` as a comparable value: an int, a float,
+            # or a [real, imag] list of floats for complex types.
+            offset = item * self.stride_size + self.start
+            if count < 2:
+                v = raw_storage_getitem(TP, self.values, offset)
+            else:
+                v = []
+                for i in range(count):
+                    v.append(raw_storage_getitem(TP, self.values,
+                                                 offset + step * i))
+            if comp_type == 'int':
+                v = intmask(v)
+            elif comp_type == 'float':
+                v = float(v)
+            elif comp_type == 'complex':
+                v = [float(v[0]), float(v[1])]
+            else:
+                raise NotImplementedError('cannot reach')
+            return v
+
+        def setitem(self, idx, item):
+            # Write a value in getitem's format back as element `idx`.
+            offset = idx * self.stride_size + self.start
+            if count < 2:
+                raw_storage_setitem(self.values, offset, rffi.cast(TP, item))
+            else:
+                i = 0
+                for val in item:
+                    raw_storage_setitem(self.values, offset + i * step,
+                                        rffi.cast(TP, val))
+                    i += 1
+
+    class ArrayRepWithStorage(Repr):
+        # Scratch buffer of `size` elements that allocates its own raw
+        # storage and frees it when collected.
+        def __init__(self, stride_size, size):
+            start = 0
+            values = alloc_raw_storage(size * stride_size,
+                                       track_allocation=False)
+            Repr.__init__(self, stride_size, size, values, start)
+
+        def __del__(self):
+            free_raw_storage(self.values, track_allocation=False)
+
+    def sort_getitem(lst, item):
+        return lst.getitem(item)
+
+    def sort_setitem(lst, item, value):
+        lst.setitem(item, value)
+
+    def sort_length(lst):
+        return lst.size
+
+    def sort_getitem_slice(lst, start, stop):
+        # Copy elements [start, stop) into a fresh scratch buffer.  The copy
+        # is packed with `element_size` rather than reusing
+        # `lst.stride_size`: a source stride along an outer axis spans many
+        # elements (wasting memory), and a stride could be negative, which
+        # is not a valid allocation size.
+        retval = ArrayRepWithStorage(element_size, stop - start)
+        for i in range(stop - start):
+            retval.setitem(i, lst.getitem(i + start))
+        return retval
+
+    if count < 2:
+        def sort_lt(a, b):
+            # Strict "less than" that sorts NaN last (x != x only for NaN).
+            return a < b or b != b and a == a
+    else:
+        def sort_lt(a, b):
+            # Complex ordering: first by NaN placement, component by
+            # component (a NaN component sorts after a non-NaN one), then
+            # lexicographically on the components.  Components that are NaN
+            # in both values compare equal and are skipped.
+            for i in range(count):
+                a_nan = a[i] != a[i]
+                b_nan = b[i] != b[i]
+                if b_nan and not a_nan:
+                    return True
+                elif a_nan and not b_nan:
+                    return False
+            for i in range(count):
+                if a[i] < b[i]:
+                    return True
+                elif a[i] > b[i]:
+                    return False
+            # Equal elements are not "less than" each other.
+            return False
+
+    TimSort = make_timsort_class(sort_getitem, sort_setitem, sort_length,
+                                 sort_getitem_slice, sort_lt)
+
+    def sort(arr, space, w_axis, itemsize):
+        # Sort `arr` in place along one axis.  w_axis None means the last
+        # axis; space.w_None means the flattened array.  `itemsize` is the
+        # element size in bytes, used as the stride for 1-d arrays.
+        if w_axis is space.w_None:
+            # note that it's fine to pass None here as we're not going
+            # to pass the result around (None is the link to base in slices)
+            arr = arr.reshape(space, None, [arr.get_size()])
+            axis = 0
+        elif w_axis is None:
+            axis = -1
+        else:
+            axis = space.int_w(w_axis)
+        shape = arr.get_shape()
+        # Normalise a negative axis, then reject out-of-range axes for
+        # arrays of every rank.
+        if axis < 0:
+            axis += len(shape)
+        if axis < 0 or axis >= len(shape):
+            raise OperationError(space.w_IndexError, space.wrap(
+                "Wrong axis %d" % axis))
+        if len(shape) == 1:
+            # NOTE(review): `itemsize` is used as the stride here, which
+            # assumes 1-d storage is contiguous; confirm for strided views
+            # such as a[::2] (arr.strides[0] may be the safer stride).
+            r = Repr(itemsize, arr.get_size(), arr.get_storage(), arr.start)
+            TimSort(r).sort()
+        else:
+            # Sort each 1-d lane along `axis`.  The iterated shape has 0 in
+            # the `axis` slot; presumably AxisIterator then visits the start
+            # offset of every lane -- confirm against AxisIterator.
+            iterable_shape = shape[:axis] + [0] + shape[axis + 1:]
+            axis_iter = AxisIterator(arr, iterable_shape, axis, False)
+            stride_size = arr.strides[axis]
+            axis_size = arr.shape[axis]
+            while not axis_iter.done():
+                r = Repr(stride_size, axis_size, arr.get_storage(),
+                         axis_iter.offset)
+                TimSort(r).sort()
+                axis_iter.next()
+
+    return sort
+
+def sort_array(arr, space, w_axis, w_order):
+    """Sort `arr` in place along `w_axis`.
+
+    `w_order` is currently ignored.  Raises an app-level
+    NotImplementedError for non-native byte order and for itemtypes that
+    have no specialised sort function.
+    """
+    cache = space.fromcache(SortCache)  # that populates SortClasses
+    itemtype = arr.dtype.itemtype
+    if not arr.dtype.native:
+        raise OperationError(space.w_NotImplementedError,
+            space.wrap("sorting of non-native byteorder not supported yet"))
+    for tp in all_types:
+        if isinstance(itemtype, tp[0]):
+            return cache._lookup(tp)(arr, space, w_axis,
+                                     itemtype.get_element_size())
+    # XXX this should probably be changed
+    raise OperationError(space.w_NotImplementedError,
+        space.wrap("sorting of non-numeric types '%s' is not implemented"
+                   % arr.dtype.get_name()))
+
+# Concrete numeric itemtypes the sort dispatcher handles (mixin classes,
+# whose __dict__ contains '_mixin_', are excluded).
+# NOTE(review): this appears to recompute the `all_types` already built
+# earlier in this module for argsort -- confirm and keep a single copy.
+all_types = unrolling_iterable([
+    i for i in (types.all_float_types + types.all_complex_types +
+                types.all_int_types)
+    if '_mixin_' not in i[0].__dict__])
+
+class ArgSortCache(object):
+    """Argsort functions specialised per itemtype, built via space.fromcache."""
+    built = False
+
+    def __init__(self, space):
+        # Only skips work if __init__ is re-run on an already-built instance.
+        if self.built:
+            return
+        self.built = True
+        cache = {}
+        for cls, it in all_types._items:
+            # complex itemtypes need two raw values per element
+            count = 2 if it == 'complex' else 1
+            cache[cls] = make_argsort_function(space, cls, it, count)
+        self.cache = cache
+        # memoised lookup keyed on the itemtype class (tp[0])
+        self._lookup = specialize.memo()(lambda tp: cache[tp[0]])
+
+
class SortCache(object):
built = False
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -635,8 +635,6 @@
# modify the array in-place
if self.is_scalar():
return
- raise OperationError(space.w_NotImplementedError, space.wrap(
- "sort not implemented yet"))
return self.implementation.sort(space, w_axis, w_order)
def descr_squeeze(self, space):
diff --git a/pypy/module/micronumpy/test/test_sorting.py b/pypy/module/micronumpy/test/test_sorting.py
--- a/pypy/module/micronumpy/test/test_sorting.py
+++ b/pypy/module/micronumpy/test/test_sorting.py
@@ -62,17 +62,30 @@
def test_sort_dtypes(self):
+ """In-place sort of a small array for several native dtypes, then of an already-sorted arange."""
from numpypy import array, arange
- nnp = self.non_native_prefix
for dtype in ['int', 'float', 'int16', 'float32', 'uint64',
- nnp + 'i2', complex]:
+ 'i2', complex]:
a = array([6, 4, -1, 3, 8, 3, 256+20, 100, 101], dtype=dtype)
+ # expected result: the same values in ascending order
+ b = array([-1, 3, 3, 4, 6, 8, 100, 101, 256+20], dtype=dtype)
c = a.copy()
a.sort()
- assert (a == [-1, 3, 3, 4, 6, 8, 100, 101, 256+20]).all(), \
+ assert (a == b).all(), \
'a,orig,dtype %r,%r,%r' % (a,c,dtype)
- a = arange(100)
+ # sorting already-sorted data must leave it unchanged
+ a = arange(100)
+ c = a.copy()
+ a.sort()
+ assert (a == c).all()
+
+ def test_sort_dtypesi_nonnative(self):
+ """Sorting a non-native byte-order array currently raises NotImplementedError mentioning 'supported'."""
+ from numpypy import array
+ nnp = self.non_native_prefix
+ for dtype in [ nnp + 'i2']:
+ a = array([6, 4, -1, 3, 8, 3, 256+20, 100, 101], dtype=dtype)
+ b = array([-1, 3, 3, 4, 6, 8, 100, 101, 256+20], dtype=dtype)
c = a.copy()
- assert (a.sort() == c).all()
+ exc = raises(NotImplementedError, a.sort)
+ assert exc.value[0].find('supported') >= 0
+ # TODO: re-enable this ordering check (uses b and c) once sorting of non-native byte order is implemented
+ #assert (a == b).all(), \
+ # 'a,orig,dtype %r,%r,%r' % (a,c,dtype)
# tests from numpy/tests/test_multiarray.py
@@ -286,8 +299,6 @@
dtype=mydtype)).all()
-
-
# tests from numpy/tests/test_regression.py
def test_sort_bigendian(self):
from numpypy import array, dtype
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit