Author: Maciej Fijalkowski <[email protected]>
Branch: missing-ndarray-attributes
Changeset: r58480:6c4a7b843097
Date: 2012-10-26 20:21 +0200
http://bitbucket.org/pypy/pypy/changeset/6c4a7b843097/

Log:    progress on sorting

diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -10,7 +10,7 @@
 from pypy.rpython.lltypesystem import rffi, lltype
 from pypy.rlib import jit
 from pypy.rlib.rawstorage import free_raw_storage
-from pypy.module.micronumpy.arrayimpl.sort import sort_array
+from pypy.module.micronumpy.arrayimpl.sort import argsort_array
 
 class ConcreteArrayIterator(base.BaseArrayIterator):
     def __init__(self, array):
@@ -404,8 +404,8 @@
                                                     self.order)
         return SliceArray(0, strides, backstrides, new_shape, self)
 
-    def argsort(self, space):
-        return sort_array(self, space)
+    def argsort(self, space, w_axis):
+        return argsort_array(self, space, w_axis)
 
 class SliceArray(BaseConcreteArray):
     def __init__(self, start, strides, backstrides, shape, parent, dtype=None):
diff --git a/pypy/module/micronumpy/arrayimpl/scalar.py b/pypy/module/micronumpy/arrayimpl/scalar.py
--- a/pypy/module/micronumpy/arrayimpl/scalar.py
+++ b/pypy/module/micronumpy/arrayimpl/scalar.py
@@ -96,5 +96,5 @@
         raise OperationError(space.w_ValueError,
                              space.wrap("scalars have no address"))
 
-    def argsort(self, space):
+    def argsort(self, space, w_axis):
         return space.wrap(0)
diff --git a/pypy/module/micronumpy/arrayimpl/sort.py b/pypy/module/micronumpy/arrayimpl/sort.py
--- a/pypy/module/micronumpy/arrayimpl/sort.py
+++ b/pypy/module/micronumpy/arrayimpl/sort.py
@@ -56,21 +56,32 @@
 
     return ArgArrayRepresentation, ArgSort
 
-def sort_array(arr, space):
+def argsort_array(arr, space, w_axis):
     itemtype = arr.dtype.itemtype
     if (not isinstance(itemtype, types.Float) and
         not isinstance(itemtype, types.Integer)):
+        # XXX this should probably be changed
         raise OperationError(space.w_NotImplementedError,
            space.wrap("sorting of non-numeric types is not implemented"))
+    if w_axis is space.w_None:
+        arr = arr.reshape(space, [arr.get_size()])
+        axis = 0
+    elif w_axis is None:
+        axis = -1
+    else:
+        axis = space.int_w(w_axis)
     itemsize = itemtype.get_element_size()
     # create array of indexes
     dtype = interp_dtype.get_dtype_cache(space).w_longdtype
-    indexes = W_NDimArray.from_shape([arr.get_size()], dtype)
-    storage = indexes.implementation.get_storage()
-    for i in range(arr.get_size()):
-        raw_storage_setitem(storage, i * INT_SIZE, i)
-    Repr, Sort = make_sort_classes(space, itemtype)
-    r = Repr(itemsize, arr.get_size(), arr.get_storage(),
-             indexes.implementation.get_storage())
-    Sort(r).sort()
+    indexes = W_NDimArray.from_shape(arr.get_shape(), dtype)
+    if len(arr.get_shape()) == 1:
+        storage = indexes.implementation.get_storage()
+        for i in range(arr.get_size()):
+            raw_storage_setitem(storage, i * INT_SIZE, i)
+        Repr, Sort = make_sort_classes(space, itemtype)
+        r = Repr(itemsize, arr.get_size(), arr.get_storage(),
+                 indexes.implementation.get_storage())
+        Sort(r).sort()
+    else:
+        xxx
     return indexes
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -366,7 +366,7 @@
         # happily ignore the kind
         # create a contiguous copy of the array
         contig = self.descr_copy(space)
-        return contig.implementation.argsort(space)
+        return contig.implementation.argsort(space, w_axis)
 
     def descr_astype(self, space, w_type):
         raise OperationError(space.w_NotImplementedError, space.wrap(
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -2141,6 +2141,22 @@
         assert array([1, 2, 3], '<i2')[::2].tostring() == '\x01\x00\x03\x00'
         assert array([1, 2, 3], '>i2')[::2].tostring() == '\x00\x01\x00\x03'
 
+    def test_argsort(self):
+        from _numpypy import array
+        a = array([[4, 2], [1, 3]])
+        assert (a.argsort() == [[1, 0], [0, 1]]).all()
+
+    def test_argsort_axis(self):
+        from _numpypy import array
+        a = array([[4, 2], [1, 3]]) 
+        assert (a.argsort(axis=None) == [2, 1, 3, 0]).all()
+        assert (a.argsort(axis=-1) == [[1, 0], [0, 1]]).all()
+        assert (a.argsort(axis=0) == [[1, 0], [0, 1]]).all()
+        assert (a.argsort(axis=1) == [[1, 0], [0, 1]]).all()
+        a = array([[3, 2, 1], [1, 2, 3]])
+        assert (a.argsort(axis=0) == [[1, 0, 0], [0, 1, 1]]).all()
+        assert (a.argsort(axis=1) == [[2, 1, 0], [0, 1, 2]]).all()
+
 class AppTestRanges(BaseNumpyAppTest):
     def test_arange(self):
         from _numpypy import arange, array, dtype
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to