Author: Matti Picus <[email protected]>
Branch: numpypy-argminmax
Changeset: r55665:6377829a0544
Date: 2012-06-14 22:58 +0300
http://bitbucket.org/pypy/pypy/changeset/6377829a0544/
Log: more input validity tests
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -187,11 +187,9 @@
get_printable_location=signature.new_printable_location(op_name),
name='numpy_' + op_name,
)
- def loop(self, space, axis, out):
+ def do_argminmax(self, axis, out):
if isinstance(self, Scalar):
return 0
- if axis >= len(self.shape):
- raise OperationError(space.w_ValueError, space.wrap("axis(=%d)
out of bounds" % axis))
sig = self.find_sig()
frame = sig.create_frame(self)
cur_best = sig.eval(frame, self)
@@ -223,6 +221,8 @@
axis = -1
else:
axis = space.int_w(w_axis)
+ if axis >= len(self.shape) or axis<0:
+ raise OperationError(space.w_ValueError,
space.wrap("axis(=%d) out of bounds" % axis))
if space.is_w(w_out, space.w_None) or not w_out:
out = None
elif not isinstance(w_out, BaseArray):
@@ -230,7 +230,28 @@
'output must be an array'))
else:
out = w_out
- return space.wrap(loop(self, space, axis, out))
+ shapelen = len(self.shape)
+ if axis<0:
+ shape = [1]
+ else:
+ shape = self.shape[:axis] + self.shape[axis + 1:]
+ #Test for shape agreement
+ if len(out.shape) > len(shape):
+ raise operationerrfmt(space.w_TypesError,
+ 'invalid shape for output array')
+ elif len(out.shape) < len(shape):
+ raise operationerrfmt(space.w_TypesError,
+ 'invalid shape for output array')
+ elif out.shape != shape:
+ raise operationerrfmt(space.w_TypesError,
+ 'invalid shape for output array')
+ #Test for dtype agreement, perhaps create an itermediate
+ #if out.dtype != self.dtype:
+ # raise OperationError(space.w_TypeError, space.wrap(
+ # "mismatched dtypes"))
+ return space.wrap(do_argminmax(self, ufunc_name,
+ return func_with_new_name(impl, "reduce_%s_impl" % ufunc_name)
+ return space.wrap(do_argminmax(self, axis, out))
return func_with_new_name(impl, "reduce_arg%s_impl" % op_name)
descr_argmax = _reduce_argmax_argmin_impl("max")
diff --git a/pypy/module/micronumpy/test/test_outarg.py b/pypy/module/micronumpy/test/test_outarg.py
--- a/pypy/module/micronumpy/test/test_outarg.py
+++ b/pypy/module/micronumpy/test/test_outarg.py
@@ -125,7 +125,7 @@
"Cannot cast ufunc negative output from dtype('float64') to
dtype('int64') with casting rule 'same_kind'"
def test_argminmax(self):
- from numpypy import arange
+ from numpypy import arange, argmin
a = arange(15).reshape(5, 3)
b = arange(15).reshape(5,3)
c = a.argmax(0, out=b[1])
@@ -134,4 +134,11 @@
c = a.argmax(1, out=b[:,1])
assert (c == [2, 2, 2, 2, 2]).all()
assert (c == b[:,1]).all()
+ raises(ValueError, argmin, a, -2)
+ raises(ValueError, argmin, a, 2)
+ b = ones((5,3), dtype=float)
+ c = a.argmin(0, out=b[1])
+ assert (b[1] == [0, 0, 0]).all()
+ assert(c.dtype is b.dtype)
+
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit