Author: Ronan Lamy <ronan.l...@gmail.com> Branch: ufunc-reduce Changeset: r78749:8586b501eab6 Date: 2015-08-02 19:30 +0100 http://bitbucket.org/pypy/pypy/changeset/8586b501eab6/
Log: Clean up shape handling in ufunc.accumulate() diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py --- a/pypy/module/micronumpy/test/test_ufuncs.py +++ b/pypy/module/micronumpy/test/test_ufuncs.py @@ -1338,6 +1338,26 @@ assert subtract.accumulate([True]*200).dtype == dtype('bool') assert divide.accumulate([True]*200).dtype == dtype('int8') + def test_accumulate_shapes(self): + import numpy as np + a = np.arange(6).reshape(2, 1, 3) + assert np.add.accumulate(a).shape == (2, 1, 3) + raises(ValueError, "np.add.accumulate(a, out=np.zeros((3, 1, 3)))") + raises(ValueError, "np.add.accumulate(a, out=np.zeros((2, 3)))") + raises(ValueError, "np.add.accumulate(a, out=np.zeros((2, 3, 1)))") + b = np.zeros((2, 1, 3)) + np.add.accumulate(a, out=b, axis=2) + assert b[0, 0, 2] == 3 + + def test_accumulate_shapes_2(self): + import sys + if '__pypy__' not in sys.builtin_module_names: + skip('PyPy-specific behavior in np.ufunc.accumulate') + import numpy as np + a = np.arange(6).reshape(2, 1, 3) + raises(ValueError, "np.add.accumulate(a, out=np.zeros((2, 1, 3, 2)))") + + def test_noncommutative_reduce_accumulate(self): import numpy as np tosubtract = np.arange(5) diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py --- a/pypy/module/micronumpy/ufuncs.py +++ b/pypy/module/micronumpy/ufuncs.py @@ -247,33 +247,25 @@ axis = axes[0] assert axis >= 0 dtype = self.find_binop_type(space, dtype) - call__array_wrap__ = True + shape = obj_shape[:] + if out: + # There appears to be a lot of accidental complexity in what + # shapes cnumpy allows for out. + # We simply require out.shape == obj.shape + if out.get_shape() != obj_shape: + raise oefmt(space.w_ValueError, + "output parameter shape mismatch, expecting " + "[%s], got [%s]", + ",".join([str(x) for x in shape]), + ",".join([str(x) for x in out.get_shape()]), + ) + dtype = out.get_dtype() + call__array_wrap__ = False + else: + out = W_NDimArray.from_shape(space, shape, dtype, + w_instance=obj) + call__array_wrap__ = True if shapelen > 1: - shape = obj_shape[:] - if out: - # Test for shape agreement - # XXX maybe we need to do broadcasting here, although I must - # say I don't understand the details for axis reduce - if out.ndims() > len(shape): - raise oefmt(space.w_ValueError, - "output parameter for reduction operation %s " - "has too many dimensions", self.name) - elif out.ndims() < len(shape): - raise oefmt(space.w_ValueError, - "output parameter for reduction operation %s " - "does not have enough dimensions", self.name) - elif out.get_shape() != shape: - raise oefmt(space.w_ValueError, - "output parameter shape mismatch, expecting " - "[%s], got [%s]", - ",".join([str(x) for x in shape]), - ",".join([str(x) for x in out.get_shape()]), - ) - call__array_wrap__ = False - dtype = out.get_dtype() - else: - out = W_NDimArray.from_shape(space, shape, dtype, - w_instance=obj) if obj.get_size() == 0: if self.identity is not None: out.fill(space, self.identity.convert_to(space, dtype)) @@ -281,14 +273,6 @@ loop.do_accumulate(space, self.func, obj, dtype, axis, out, self.identity) else: - if out: - call__array_wrap__ = False - if out.get_shape() != [obj.get_size()]: - raise OperationError(space.w_ValueError, space.wrap( - "out of incompatible size")) - else: - out = W_NDimArray.from_shape(space, [obj.get_size()], dtype, - w_instance=obj) loop.compute_reduce_cumulative(space, obj, out, dtype, self.func, self.identity) if call__array_wrap__: _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit