Author: Matti Picus <matti.pi...@gmail.com>
Branch: numpypy-argminmax
Changeset: r55665:6377829a0544
Date: 2012-06-14 22:58 +0300
http://bitbucket.org/pypy/pypy/changeset/6377829a0544/
Log: more input validity tests diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -187,11 +187,9 @@ get_printable_location=signature.new_printable_location(op_name), name='numpy_' + op_name, ) - def loop(self, space, axis, out): + def do_argminmax(self, axis, out): if isinstance(self, Scalar): return 0 - if axis >= len(self.shape): - raise OperationError(space.w_ValueError, space.wrap("axis(=%d) out of bounds" % axis)) sig = self.find_sig() frame = sig.create_frame(self) cur_best = sig.eval(frame, self) @@ -223,6 +221,8 @@ axis = -1 else: axis = space.int_w(w_axis) + if axis >= len(self.shape) or axis<0: + raise OperationError(space.w_ValueError, space.wrap("axis(=%d) out of bounds" % axis)) if space.is_w(w_out, space.w_None) or not w_out: out = None elif not isinstance(w_out, BaseArray): @@ -230,7 +230,28 @@ 'output must be an array')) else: out = w_out - return space.wrap(loop(self, space, axis, out)) + shapelen = len(self.shape) + if axis<0: + shape = [1] + else: + shape = self.shape[:axis] + self.shape[axis + 1:] + #Test for shape agreement + if len(out.shape) > len(shape): + raise operationerrfmt(space.w_TypesError, + 'invalid shape for output array') + elif len(out.shape) < len(shape): + raise operationerrfmt(space.w_TypesError, + 'invalid shape for output array') + elif out.shape != shape: + raise operationerrfmt(space.w_TypesError, + 'invalid shape for output array') + #Test for dtype agreement, perhaps create an itermediate + #if out.dtype != self.dtype: + # raise OperationError(space.w_TypeError, space.wrap( + # "mismatched dtypes")) + return space.wrap(do_argminmax(self, ufunc_name, + return func_with_new_name(impl, "reduce_%s_impl" % ufunc_name) + return space.wrap(do_argminmax(self, axis, out)) return func_with_new_name(impl, "reduce_arg%s_impl" % op_name) descr_argmax = _reduce_argmax_argmin_impl("max") 
diff --git a/pypy/module/micronumpy/test/test_outarg.py b/pypy/module/micronumpy/test/test_outarg.py --- a/pypy/module/micronumpy/test/test_outarg.py +++ b/pypy/module/micronumpy/test/test_outarg.py @@ -125,7 +125,7 @@ "Cannot cast ufunc negative output from dtype('float64') to dtype('int64') with casting rule 'same_kind'" def test_argminmax(self): - from numpypy import arange + from numpypy import arange, argmin a = arange(15).reshape(5, 3) b = arange(15).reshape(5,3) c = a.argmax(0, out=b[1]) @@ -134,4 +134,11 @@ c = a.argmax(1, out=b[:,1]) assert (c == [2, 2, 2, 2, 2]).all() assert (c == b[:,1]).all() + raises(ValueError, argmin, a, -2) + raises(ValueError, argmin, a, 2) + b = ones((5,3), dtype=float) + c = a.argmin(0, out=b[1]) + assert (b[1] == [0, 0, 0]).all() + assert(c.dtype is b.dtype) + _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit