Author: mattip <[email protected]>
Branch: 
Changeset: r77302:30fa3802a882
Date: 2015-05-13 07:41 +0300
http://bitbucket.org/pypy/pypy/changeset/30fa3802a882/

Log:    test, fix for issue #2046

diff --git a/pypy/module/micronumpy/concrete.py b/pypy/module/micronumpy/concrete.py
--- a/pypy/module/micronumpy/concrete.py
+++ b/pypy/module/micronumpy/concrete.py
@@ -519,6 +519,9 @@
         return self.__class__(self.start, new_strides, new_backstrides, new_shape,
                           self, orig_array)
 
+    def sort(self, space, w_axis, w_order):
+        from .selection import sort_array
+        return sort_array(self, space, w_axis, w_order)
 
 class NonWritableSliceArray(SliceArray):
     def descr_setitem(self, space, orig_array, w_index, w_value):
diff --git a/pypy/module/micronumpy/selection.py b/pypy/module/micronumpy/selection.py
--- a/pypy/module/micronumpy/selection.py
+++ b/pypy/module/micronumpy/selection.py
@@ -120,7 +120,7 @@
     ArgSort = make_timsort_class(arg_getitem, arg_setitem, arg_length,
                                  arg_getitem_slice, arg_lt)
 
-    def argsort(arr, space, w_axis, itemsize):
+    def argsort(arr, space, w_axis):
         if w_axis is space.w_None:
             # note that it's fine ot pass None here as we're not going
             # to pass the result around (None is the link to base in slices)
@@ -138,7 +138,7 @@
             if len(arr.get_shape()) == 1:
                 for i in range(arr.get_size()):
                     raw_storage_setitem(storage, i * INT_SIZE, i)
-                r = Repr(INT_SIZE, itemsize, arr.get_size(), arr_storage,
+                r = Repr(INT_SIZE, arr.strides[0], arr.get_size(), arr_storage,
                          storage, 0, arr.start)
                 ArgSort(r).sort()
             else:
@@ -174,8 +174,7 @@
     itemtype = arr.dtype.itemtype
     for tp in all_types:
         if isinstance(itemtype, tp[0]):
-            return cache._lookup(tp)(arr, space, w_axis,
-                                     itemtype.get_element_size())
+            return cache._lookup(tp)(arr, space, w_axis)
     # XXX this should probably be changed
     raise oefmt(space.w_NotImplementedError,
                 "sorting of non-numeric types '%s' is not implemented",
@@ -272,7 +271,7 @@
     ArgSort = make_timsort_class(arg_getitem, arg_setitem, arg_length,
                                  arg_getitem_slice, arg_lt)
 
-    def sort(arr, space, w_axis, itemsize):
+    def sort(arr, space, w_axis):
         if w_axis is space.w_None:
             # note that it's fine to pass None here as we're not going
             # to pass the result around (None is the link to base in slices)
@@ -284,7 +283,7 @@
             axis = space.int_w(w_axis)
         with arr as storage:
             if len(arr.get_shape()) == 1:
-                r = Repr(itemsize, arr.get_size(), storage,
+                r = Repr(arr.strides[0], arr.get_size(), storage,
                          arr.start)
                 ArgSort(r).sort()
             else:
@@ -313,8 +312,7 @@
                     "sorting of non-native byteorder not supported yet")
     for tp in all_types:
         if isinstance(itemtype, tp[0]):
-            return cache._lookup(tp)(arr, space, w_axis,
-                                     itemtype.get_element_size())
+            return cache._lookup(tp)(arr, space, w_axis)
     # XXX this should probably be changed
     raise oefmt(space.w_NotImplementedError,
                 "sorting of non-numeric types '%s' is not implemented",
diff --git a/pypy/module/micronumpy/test/test_selection.py b/pypy/module/micronumpy/test/test_selection.py
--- a/pypy/module/micronumpy/test/test_selection.py
+++ b/pypy/module/micronumpy/test/test_selection.py
@@ -82,6 +82,13 @@
             #assert (a == b).all(), \
             #    'a,orig,dtype %r,%r,%r' % (a,c,dtype)
 
+    def test_sort_noncontiguous(self):
+        from numpy import array
+        x = array([[2, 10], [1, 11]])
+        assert (x[:, 0].argsort() == [1, 0]).all()
+        x[:, 0].sort()
+        assert (x == [[1, 10], [2, 11]]).all()
+
 # tests from numpy/tests/test_multiarray.py
     def test_sort_corner_cases(self):
         # test ordering for floats and complex containing nans. It is only
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to