Author: Brian Kearns <bdkea...@gmail.com> Branch: Changeset: r67400:3fd593fe30bf Date: 2013-10-15 19:09 -0400 http://bitbucket.org/pypy/pypy/changeset/3fd593fe30bf/
Log: implement and test ndarray.trace() diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -550,6 +550,12 @@ return interp_arrayops.diagonal(space, self.implementation, offset, axis1, axis2) + @unwrap_spec(offset=int, axis1=int, axis2=int) + def descr_trace(self, space, offset=0, axis1=0, axis2=1, + w_dtype=None, w_out=None): + diag = self.descr_diagonal(space, offset, axis1, axis2) + return diag.descr_sum(space, w_axis=space.wrap(-1), w_dtype=w_dtype, w_out=w_out) + def descr_dump(self, space, w_file): raise OperationError(space.w_NotImplementedError, space.wrap( "dump not implemented yet")) @@ -653,11 +659,6 @@ raise OperationError(space.w_NotImplementedError, space.wrap( "tofile not implemented yet")) - def descr_trace(self, space, w_offset=0, w_axis1=0, w_axis2=1, - w_dtype=None, w_out=None): - raise OperationError(space.w_NotImplementedError, space.wrap( - "trace not implemented yet")) - def descr_view(self, space, w_dtype=None, w_type=None) : if not w_type and w_dtype: try: @@ -1153,6 +1154,7 @@ round = interp2app(W_NDimArray.descr_round), data = GetSetProperty(W_NDimArray.descr_get_data), diagonal = interp2app(W_NDimArray.descr_diagonal), + trace = interp2app(W_NDimArray.descr_trace), view = interp2app(W_NDimArray.descr_view), ctypes = GetSetProperty(W_NDimArray.descr_get_ctypes), # XXX unimplemented diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py --- a/pypy/module/micronumpy/test/test_numarray.py +++ b/pypy/module/micronumpy/test/test_numarray.py @@ -1465,6 +1465,14 @@ assert a[3].imag == -10 assert a[2].imag == -5 + def test_trace(self): + import numpypy as np + assert np.trace(np.eye(3)) == 3.0 + a = np.arange(8).reshape((2,2,2)) + assert np.array_equal(np.trace(a), [6, 8]) + a = np.arange(24).reshape((2,2,2,3)) + assert np.trace(a).shape == (2, 3) + def test_view(self): from numpypy import array, int8, int16, dtype x = array((1, 2), dtype=int8) _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit