Author: Maciej Fijalkowski <[email protected]>
Branch: missing-ndarray-attributes
Changeset: r58480:6c4a7b843097
Date: 2012-10-26 20:21 +0200
http://bitbucket.org/pypy/pypy/changeset/6c4a7b843097/
Log: progress on sorting
diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -10,7 +10,7 @@
from pypy.rpython.lltypesystem import rffi, lltype
from pypy.rlib import jit
from pypy.rlib.rawstorage import free_raw_storage
-from pypy.module.micronumpy.arrayimpl.sort import sort_array
+from pypy.module.micronumpy.arrayimpl.sort import argsort_array
class ConcreteArrayIterator(base.BaseArrayIterator):
def __init__(self, array):
@@ -404,8 +404,8 @@
self.order)
return SliceArray(0, strides, backstrides, new_shape, self)
- def argsort(self, space):
- return sort_array(self, space)
+ def argsort(self, space, w_axis):
+ # Delegate to argsort_array (sort.py). w_axis is passed through unchanged:
+ # None, space.w_None or an app-level int, each interpreted there.
+ return argsort_array(self, space, w_axis)
class SliceArray(BaseConcreteArray):
def __init__(self, start, strides, backstrides, shape, parent, dtype=None):
diff --git a/pypy/module/micronumpy/arrayimpl/scalar.py b/pypy/module/micronumpy/arrayimpl/scalar.py
--- a/pypy/module/micronumpy/arrayimpl/scalar.py
+++ b/pypy/module/micronumpy/arrayimpl/scalar.py
@@ -96,5 +96,5 @@
raise OperationError(space.w_ValueError,
space.wrap("scalars have no address"))
- def argsort(self, space):
+ def argsort(self, space, w_axis):
+ # w_axis is accepted only to mirror the concrete-array argsort signature;
+ # it is ignored, because a scalar always argsorts to index 0.
return space.wrap(0)
diff --git a/pypy/module/micronumpy/arrayimpl/sort.py b/pypy/module/micronumpy/arrayimpl/sort.py
--- a/pypy/module/micronumpy/arrayimpl/sort.py
+++ b/pypy/module/micronumpy/arrayimpl/sort.py
@@ -56,21 +56,32 @@
return ArgArrayRepresentation, ArgSort
-def sort_array(arr, space):
+def argsort_array(arr, space, w_axis):
+ # Return a new array of indexes (w_longdtype) that would sort arr.
+ # w_axis: space.w_None -> arr is flattened and sorted as 1-D (axis 0);
+ #         None (argument omitted) -> axis -1; otherwise an app-level int.
+ # Only Float and Integer item types are supported; any other item type
+ # raises an app-level NotImplementedError (via OperationError).
itemtype = arr.dtype.itemtype
if (not isinstance(itemtype, types.Float) and
not isinstance(itemtype, types.Integer)):
+ # XXX this should probably be changed
raise OperationError(space.w_NotImplementedError,
space.wrap("sorting of non-numeric types is not implemented"))
+ if w_axis is space.w_None:
+ arr = arr.reshape(space, [arr.get_size()])
+ axis = 0
+ elif w_axis is None:
+ axis = -1
+ else:
+ axis = space.int_w(w_axis)
+ # NOTE(review): `axis` is resolved above but never used below. It is also
+ # not range-checked: there is no normalization of negative values and no
+ # error for an out-of-range axis.
itemsize = itemtype.get_element_size()
# create array of indexes
dtype = interp_dtype.get_dtype_cache(space).w_longdtype
- indexes = W_NDimArray.from_shape([arr.get_size()], dtype)
- storage = indexes.implementation.get_storage()
- for i in range(arr.get_size()):
- raw_storage_setitem(storage, i * INT_SIZE, i)
- Repr, Sort = make_sort_classes(space, itemtype)
- r = Repr(itemsize, arr.get_size(), arr.get_storage(),
- indexes.implementation.get_storage())
- Sort(r).sort()
+ # The result has arr's shape (already flattened when w_axis was w_None).
+ indexes = W_NDimArray.from_shape(arr.get_shape(), dtype)
+ if len(arr.get_shape()) == 1:
+ # 1-D case: fill indexes with 0..n-1, then let Sort reorder them keyed on
+ # arr's raw storage. Presumably arr must be contiguous here; the caller
+ # in interp_numarray passes a descr_copy(). TODO confirm for other callers.
+ storage = indexes.implementation.get_storage()
+ for i in range(arr.get_size()):
+ raw_storage_setitem(storage, i * INT_SIZE, i)
+ Repr, Sort = make_sort_classes(space, itemtype)
+ r = Repr(itemsize, arr.get_size(), arr.get_storage(),
+ indexes.implementation.get_storage())
+ Sort(r).sort()
+ else:
+ # NOTE(review): `xxx` looks like a work-in-progress placeholder and is
+ # presumably an undefined name. Any array that does not have exactly one
+ # dimension fails here instead of being sorted along `axis`. That covers
+ # the 2-D cases in test_argsort / test_argsort_axis.
+ xxx
return indexes
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -366,7 +366,7 @@
# happily ignore the kind
# create a contiguous copy of the array
contig = self.descr_copy(space)
- return contig.implementation.argsort(space)
+ return contig.implementation.argsort(space, w_axis)
def descr_astype(self, space, w_type):
raise OperationError(space.w_NotImplementedError, space.wrap(
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -2141,6 +2141,22 @@
assert array([1, 2, 3], '<i2')[::2].tostring() == '\x01\x00\x03\x00'
assert array([1, 2, 3], '>i2')[::2].tostring() == '\x00\x01\x00\x03'
+ def test_argsort(self):
+ # Default axis: argsort_array maps an omitted axis (w_axis None) to -1, so
+ # each row is argsorted independently. This assumes descr_argsort passes
+ # None when axis is not given; TODO confirm.
+ # NOTE(review): 2-D input currently reaches the `xxx` branch of
+ # argsort_array, so this presumably fails until that branch is implemented.
+ from _numpypy import array
+ a = array([[4, 2], [1, 3]])
+ assert (a.argsort() == [[1, 0], [0, 1]]).all()
+
+ def test_argsort_axis(self):
+ # axis=None argsorts the flattened array [4, 2, 1, 3].
+ # axis=-1 and axis=1 argsort within each row; axis=0 argsorts within each column.
+ # In the second array, column 1 holds the tie [2, 2], which is expected as [0, 1].
+ from _numpypy import array
+ a = array([[4, 2], [1, 3]])
+ assert (a.argsort(axis=None) == [2, 1, 3, 0]).all()
+ assert (a.argsort(axis=-1) == [[1, 0], [0, 1]]).all()
+ assert (a.argsort(axis=0) == [[1, 0], [0, 1]]).all()
+ assert (a.argsort(axis=1) == [[1, 0], [0, 1]]).all()
+ a = array([[3, 2, 1], [1, 2, 3]])
+ assert (a.argsort(axis=0) == [[1, 0, 0], [0, 1, 1]]).all()
+ assert (a.argsort(axis=1) == [[2, 1, 0], [0, 1, 2]]).all()
+
class AppTestRanges(BaseNumpyAppTest):
def test_arange(self):
from _numpypy import arange, array, dtype
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit