Author: Maciej Fijalkowski <[email protected]>
Branch: missing-ndarray-attributes
Changeset: r58649:254c84940fc7
Date: 2012-10-31 18:21 +0200
http://bitbucket.org/pypy/pypy/changeset/254c84940fc7/

Log:    simple case of diagonal

diff --git a/pypy/module/micronumpy/interp_arrayops.py b/pypy/module/micronumpy/interp_arrayops.py
--- a/pypy/module/micronumpy/interp_arrayops.py
+++ b/pypy/module/micronumpy/interp_arrayops.py
@@ -170,3 +170,15 @@
                              space.wrap("mode %s not known" % (mode,)))
     loop.choose(space, arr, choices, shape, dtype, out, MODES[mode])
     return out
+
+def diagonal(space, arr, offset, axis1, axis2):
+    shape = arr.get_shape()
+    size = min(shape[axis1], shape[axis2] - offset)
+    dtype = arr.dtype
+    if len(shape) == 2:
+        # simple case
+        out = W_NDimArray.from_shape([size], dtype)
+        loop.diagonal_simple(space, arr, out, offset, axis1, axis2, size)
+        return out
+    else:
+        xxx
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -429,10 +429,19 @@
     def descr_get_data(self, space):
         return self.implementation.get_buffer(space)
 
-    def descr_diagonal(self, space, w_offset=0, w_axis1=0, w_axis2=1): 
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            "diagonal not implemented yet"))
-
+    @unwrap_spec(offset=int, axis1=int, axis2=int)
+    def descr_diagonal(self, space, offset=0, axis1=0, axis2=1):
+        if len(self.get_shape()) < 2:
+            raise OperationError(space.w_ValueError, space.wrap(
+                "need at least 2 dimensions for diagonal"))
+        if (axis1 < 0 or axis2 < 0 or axis1 >= len(self.get_shape()) or
+            axis2 >= len(self.get_shape())):
+            raise operationerrfmt(space.w_ValueError,
+                 "axis1(=%d) and axis2(=%d) must be withing range (ndim=%d)",
+                                  axis1, axis2, len(self.get_shape()))
+        return interp_arrayops.diagonal(space, self.implementation, offset,
+                                        axis1, axis2)
+    
     def descr_dump(self, space, w_file):
         raise OperationError(space.w_NotImplementedError, space.wrap(
             "dump not implemented yet"))
@@ -801,6 +810,7 @@
     choose   = interp2app(W_NDimArray.descr_choose),
     clip     = interp2app(W_NDimArray.descr_clip),
     data     = GetSetProperty(W_NDimArray.descr_get_data),
+    diagonal = interp2app(W_NDimArray.descr_diagonal),
 
     ctypes = GetSetProperty(W_NDimArray.descr_get_ctypes), # XXX unimplemented
 
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -580,3 +580,21 @@
         max_iter.next()
         out_iter.next()
         min_iter.next()
+
+diagonal_simple_driver = jit.JitDriver(greens = ['axis1', 'axis2'],
+                                       reds = ['i', 'offset', 'out_iter',
+                                               'arr'])
+
+def diagonal_simple(space, arr, out, offset, axis1, axis2, size):
+    out_iter = out.create_iter()
+    i = 0
+    index = [0] * 2
+    while i < size:
+        diagonal_simple_driver.jit_merge_point(axis1=axis1, axis2=axis2,
+                                               out_iter=out_iter, offset=offset,
+                                               i=i, arr=arr)
+        index[axis1] = i
+        index[axis2] = i + offset
+        out_iter.setitem(arr.getitem_index(space, index))
+        i += 1
+        out_iter.next()
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -2110,6 +2110,13 @@
         assert (a.cumsum(1) == [[1, 2], [2, 4], [3, 7]]).all()
         assert (a.cumsum(0) == [[1, 1], [3, 3], [6, 7]]).all()
 
+    def test_diagonal(self):
+        from _numpypy import array
+        a = array([[1, 2], [3, 4], [5, 6]])
+        raises(ValueError, 'array([1, 2]).diagonal()')
+        assert (a.diagonal() == [1, 4]).all()
+        assert (a.diagonal(1) == [2]).all()
+
 class AppTestSupport(BaseNumpyAppTest):
     def setup_class(cls):
         import struct
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to