Author: mattip <matti.pi...@gmail.com>
Branch: numpy_broadcast
Changeset: r83712:16f4a95d75ee
Date: 2016-04-17 16:33 +0300
http://bitbucket.org/pypy/pypy/changeset/16f4a95d75ee/

Log:    test, implement overflow checking; small cleanups

diff --git a/pypy/module/micronumpy/broadcast.py b/pypy/module/micronumpy/broadcast.py
--- a/pypy/module/micronumpy/broadcast.py
+++ b/pypy/module/micronumpy/broadcast.py
@@ -1,6 +1,6 @@
 import pypy.module.micronumpy.constants as NPY
 from nditer import ConcreteIter, parse_op_flag, parse_op_arg
-from pypy.interpreter.error import OperationError
+from pypy.interpreter.error import OperationError, oefmt
 from pypy.interpreter.gateway import interp2app
 from pypy.interpreter.typedef import TypeDef, GetSetProperty
 from pypy.module.micronumpy import support
@@ -8,6 +8,8 @@
 from rpython.rlib import jit
 from strides import calculate_broadcast_strides, shape_agreement_multiple
 
+def descr_new_broadcast(space, w_subtype, __args__):
+    return W_Broadcast(space, __args__.arguments_w)
 
 class W_Broadcast(W_NumpyObject):
     """
@@ -15,15 +17,11 @@
     This class is a simplified version of nditer.W_NDIter with fixed iteration for broadcasted arrays.
     """
 
-    @staticmethod
-    def descr_new_broadcast(space, w_subtype, __args__):
-        return W_Broadcast(space, __args__.arguments_w)
-
     def __init__(self, space, args):
         num_args = len(args)
         if not (2 <= num_args <= NPY.MAXARGS):
-            raise OperationError(space.w_ValueError,
-                                 space.wrap("Need at least two and fewer than (%d) array objects." % NPY.MAXARGS))
+            raise oefmt(space.w_ValueError,
+                                 "Need at least two and fewer than (%d) array objects.", NPY.MAXARGS)
 
         self.seq = [convert_to_array(space, w_elem)
                     for w_elem in args]
@@ -37,7 +35,10 @@
         self.iters = []
         self.index = 0
 
-        self.size = support.product(self.shape)
+        try:
+            self.size = support.product_check(self.shape)
+        except OverflowError as e:
+            raise oefmt(space.w_ValueError, "broadcast dimensions too large.")
         for i in range(len(self.seq)):
             it = self.get_iter(space, i)
             it.contiguous = False
@@ -99,7 +100,7 @@
 
 
 W_Broadcast.typedef = TypeDef("numpy.broadcast",
-                              __new__=interp2app(W_Broadcast.descr_new_broadcast),
+                              __new__=interp2app(descr_new_broadcast),
                               __iter__=interp2app(W_Broadcast.descr_iter),
                               next=interp2app(W_Broadcast.descr_next),
                               shape=GetSetProperty(W_Broadcast.descr_get_shape),
diff --git a/pypy/module/micronumpy/test/test_broadcast.py b/pypy/module/micronumpy/test/test_broadcast.py
--- a/pypy/module/micronumpy/test/test_broadcast.py
+++ b/pypy/module/micronumpy/test/test_broadcast.py
@@ -55,11 +55,19 @@
         assert b == [(1, 4), (2, 5), (3, 6)]
         assert b[0][0].dtype == x.dtype
 
-    def test_broadcast_linear_unequal(self):
+    def test_broadcast_failures(self):
         import numpy as np
+        import sys
         x = np.array([1, 2, 3])
         y = np.array([4, 5])
         raises(ValueError, np.broadcast, x, y)
+        a = np.empty(2**16,dtype='int8')
+        a = a.reshape(-1, 1, 1, 1)
+        b = a.reshape(1, -1, 1, 1)
+        c = a.reshape(1, 1, -1, 1)
+        d = a.reshape(1, 1, 1, -1)
+        exc = raises(ValueError, np.broadcast, a, b, c, d)
+        assert exc.value[0] == ('broadcast dimensions too large.')
 
     def test_broadcast_3_args(self):
         import numpy as np
@@ -82,7 +90,8 @@
         for j in range(35):
             arrs = [arr] * j
             if j < 2 or j > 32:
-                raises(ValueError, np.broadcast, *arrs)
+                exc = raises(ValueError, np.broadcast, *arrs)
+                assert exc.value[0] == ('Need at least two and fewer than (32) array objects.')
             else:
                 mit = np.broadcast(*arrs)
                 assert mit.numiter == j
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to