Author: mattip <[email protected]>
Branch: nupypy-axis-arg-check
Changeset: r55699:f9ac9d6e6db8
Date: 2012-06-16 23:07 +0300
http://bitbucket.org/pypy/pypy/changeset/f9ac9d6e6db8/

Log:    passes all tests, ready for review

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -15,7 +15,6 @@
 from pypy.rlib.rstring import StringBuilder
 from pypy.rpython.lltypesystem import lltype, rffi
 from pypy.tool.sourcetools import func_with_new_name
-from pypy.rlib.rarithmetic import maxint
 
 
 count_driver = jit.JitDriver(
@@ -157,18 +156,6 @@
 
     def _reduce_ufunc_impl(ufunc_name, promote_to_largest=False):
         def impl(self, space, w_axis=None, w_out=None):
-            if space.is_w(w_axis, space.w_None):
-                axis = maxint
-            else:
-                axis = space.int_w(w_axis)
-                shapelen = len(self.shape)
-                if axis < -shapelen or axis>= shapelen:
-                    raise operationerrfmt(space.w_ValueError,
-                        "axis entry %d is out of bounds [%d, %d)", axis,
-                        -shapelen, shapelen)
-                if axis < 0:
-                    axis += shapelen
-
             if space.is_w(w_out, space.w_None) or not w_out:
                 out = None
             elif not isinstance(w_out, BaseArray):
@@ -177,7 +164,7 @@
             else:
                 out = w_out
             return getattr(interp_ufuncs.get(space), ufunc_name).reduce(space,
-                                        self, True, promote_to_largest, axis,
+                                        self, True, promote_to_largest, w_axis,
                                                                    False, out)
         return func_with_new_name(impl, "reduce_%s_impl" % ufunc_name)
 
@@ -1326,22 +1313,23 @@
         raise OperationError(space.w_NotImplementedError, 
space.wrap("unsupported"))
     if space.is_w(w_axis, space.w_None):
         return space.wrap(support.product(arr.shape))
+    shapelen = len(arr.shape)
     if space.isinstance_w(w_axis, space.w_int):
         axis = space.int_w(w_axis)
-        if axis < -arr.shapelen or axis>= arr.shapelen:
+        if axis < -shapelen or axis>= shapelen:
             raise operationerrfmt(space.w_ValueError,
                 "axis entry %d is out of bounds [%d, %d)", axis,
-                -arr.shapelen, arr.shapelen)
+                -shapelen, shapelen)
         return space.wrap(arr.shape[axis])    
     # numpy as of June 2012 does not implement this 
     s = 1
     elems = space.fixedview(w_axis)
     for w_elem in elems:
         axis = space.int_w(w_elem)
-        if axis < -arr.shapelen or axis>= arr.shapelen:
+        if axis < -shapelen or axis>= shapelen:
             raise operationerrfmt(space.w_ValueError,
                 "axis entry %d is out of bounds [%d, %d)", axis,
-                -arr.shapelen, arr.shapelen)
+                -shapelen, shapelen)
         s *= arr.shape[axis]
     return space.wrap(s)
 
diff --git a/pypy/module/micronumpy/interp_support.py b/pypy/module/micronumpy/interp_support.py
--- a/pypy/module/micronumpy/interp_support.py
+++ b/pypy/module/micronumpy/interp_support.py
@@ -4,6 +4,7 @@
 from pypy.module.micronumpy import interp_dtype
 from pypy.objspace.std.strutil import strip_spaces
 from pypy.rlib import jit
+from pypy.rlib.rarithmetic import maxint
 
 FLOAT_SIZE = rffi.sizeof(lltype.Float)
 
@@ -103,3 +104,16 @@
         return _fromstring_bin(space, s, count, length, dtype)
     else:
         return _fromstring_text(space, s, count, sep, length, dtype)
+
+def unwrap_axis_arg(space, shapelen, w_axis):
+    if space.is_w(w_axis, space.w_None) or not w_axis:
+        axis = maxint
+    else:
+        axis = space.int_w(w_axis)
+        if axis < -shapelen or axis>= shapelen:
+            raise operationerrfmt(space.w_ValueError,
+                "axis entry %d is out of bounds [%d, %d)", axis,
+                -shapelen, shapelen)
+        if axis < 0:
+            axis += shapelen
+    return axis
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -2,11 +2,11 @@
 from pypy.interpreter.error import OperationError, operationerrfmt
 from pypy.interpreter.gateway import interp2app, unwrap_spec, NoneNotWrapped
 from pypy.interpreter.typedef import TypeDef, GetSetProperty, 
interp_attrproperty
-from pypy.module.micronumpy import interp_boxes, interp_dtype, support, loop
+from pypy.module.micronumpy import interp_boxes, interp_dtype, loop
 from pypy.rlib import jit
 from pypy.rlib.rarithmetic import LONG_BIT
 from pypy.tool.sourcetools import func_with_new_name
-from pypy.rlib.rarithmetic import maxint
+from pypy.module.micronumpy.interp_support import unwrap_axis_arg
 
 class W_Ufunc(Wrappable):
     _attrs_ = ["name", "promote_to_float", "promote_bools", "identity"]
@@ -121,18 +121,7 @@
         """
         from pypy.module.micronumpy.interp_numarray import BaseArray
         if w_axis is None:
-            axis = 0
-        elif space.is_w(w_axis, space.w_None):
-            axis = maxint
-        else:
-            axis = space.int_w(w_axis)
-            shapelen = len(self.shape)
-            if axis < -shapelen or axis>= shapelen:
-                raise operationerrfmt(space.w_ValueError,
-                    "axis entry %d is out of bounds [%d, %d)", axis,
-                    -shapelen, shapelen)
-            if axis < 0:
-                axis += shapelen
+            w_axis = space.wrap(0)
         if space.is_w(w_out, space.w_None):
             out = None
         elif not isinstance(w_out, BaseArray):
@@ -140,9 +129,9 @@
                                                 'output must be an array'))
         else:
             out = w_out
-        return self.reduce(space, w_obj, False, False, axis, keepdims, out)
+        return self.reduce(space, w_obj, False, False, w_axis, keepdims, out)
 
-    def reduce(self, space, w_obj, multidim, promote_to_largest, axis,
+    def reduce(self, space, w_obj, multidim, promote_to_largest, w_axis,
                keepdims=False, out=None):
         from pypy.module.micronumpy.interp_numarray import convert_to_array, \
                                              Scalar, ReduceArray, W_NDimArray
@@ -150,11 +139,12 @@
             raise OperationError(space.w_ValueError, space.wrap("reduce only "
                 "supported for binary functions"))
         assert isinstance(self, W_Ufunc2)
-        assert axis>=0
         obj = convert_to_array(space, w_obj)
         if isinstance(obj, Scalar):
             raise OperationError(space.w_TypeError, space.wrap("cannot reduce "
                 "on a scalar"))
+        axis = unwrap_axis_arg(space, len(obj.shape), w_axis)    
+        assert axis>=0
         size = obj.size
         if self.comparison_func:
             dtype = interp_dtype.get_dtype_cache(space).w_booldtype
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -637,7 +637,7 @@
         assert count_reduce_items(a) == 24
         assert count_reduce_items(a, 1) == 3
         assert count_reduce_items(a, (1, 2)) == 3 * 4
-        raises(ValueError, count_reduce_items, a, -3)
+        raises(ValueError, count_reduce_items, a, -4)
         raises(ValueError, count_reduce_items, a, (0, 2, -4))
 
     def test_true_divide(self):
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to