Author: Sergey Matyunin <[email protected]>
Branch: numpy_broadcast
Changeset: r83709:bbb4848772d2
Date: 2016-04-08 10:34 +0200
http://bitbucket.org/pypy/pypy/changeset/bbb4848772d2/

Log:    Implemented W_Broadcast for numpy.broadcast

diff --git a/pypy/module/micronumpy/__init__.py b/pypy/module/micronumpy/__init__.py
--- a/pypy/module/micronumpy/__init__.py
+++ b/pypy/module/micronumpy/__init__.py
@@ -32,6 +32,7 @@
         'set_string_function': 'appbridge.set_string_function',
         'typeinfo': 'descriptor.get_dtype_cache(space).w_typeinfo',
         'nditer': 'nditer.W_NDIter',
+        'broadcast': 'broadcast.W_Broadcast',
 
         'set_docstring': 'support.descr_set_docstring',
         'VisibleDeprecationWarning': 'support.W_VisibleDeprecationWarning',
diff --git a/pypy/module/micronumpy/broadcast.py b/pypy/module/micronumpy/broadcast.py
new file mode 100644
--- /dev/null
+++ b/pypy/module/micronumpy/broadcast.py
@@ -0,0 +1,100 @@
+import operator
+
+import pypy.module.micronumpy.constants as NPY
+from nditer import ConcreteIter, parse_op_flag, parse_op_arg
+from pypy.interpreter.error import OperationError
+from pypy.interpreter.gateway import interp2app
+from pypy.interpreter.typedef import TypeDef, GetSetProperty
+from pypy.module.micronumpy.base import W_NDimArray, convert_to_array, W_NumpyObject
+from rpython.rlib import jit
+from strides import calculate_broadcast_strides, shape_agreement_multiple
+
+
+class W_Broadcast(W_NumpyObject):
+    """
+    Implementation of numpy.broadcast.
+    This class is a simplified version of nditer.W_NDIter with fixed iteration for broadcasted arrays.
+    """
+
+    @staticmethod
+    def descr_new_broadcast(space, w_subtype, __args__):
+        return W_Broadcast(space, __args__.arguments_w)
+
+    def __init__(self, space, w_args):
+        self.seq = [convert_to_array(space, w_elem)
+                    for w_elem in w_args]
+
+        self.op_flags = parse_op_arg(space, 'op_flags', space.w_None,
+                                     len(self.seq), parse_op_flag)
+
+        self.shape = tuple(shape_agreement_multiple(space, self.seq, shape=None))
+        self.order = NPY.CORDER
+
+        self.iters = []
+        self.index = 0
+        self.size = reduce(operator.mul, self.shape, 1)
+        for i in range(len(self.seq)):
+            it = self.get_iter(space, i)
+            it.contiguous = False
+            self.iters.append((it, it.reset()))
+
+        self.done = False
+        pass
+
+    def get_iter(self, space, i):
+        arr = self.seq[i]
+        imp = arr.implementation
+        if arr.is_scalar():
+            return ConcreteIter(imp, 1, [], [], [], self.op_flags[i], self)
+        shape = self.shape
+
+        backward = imp.order != self.order
+
+        r = calculate_broadcast_strides(imp.strides, imp.backstrides, imp.shape,
+                                        shape, backward)
+
+        iter_shape = shape
+        if len(shape) != len(r[0]):
+            # shape can be shorter when using an external loop, just return a view
+            iter_shape = imp.shape
+        return ConcreteIter(imp, imp.get_size(), iter_shape, r[0], r[1],
+                            self.op_flags[i], self)
+
+    def descr_iter(self, space):
+        return space.wrap(self)
+
+    def descr_get_shape(self, space):
+        return space.wrap(self.shape)
+
+    def descr_get_size(self, space):
+        return space.wrap(self.size)
+
+    def descr_get_index(self, space):
+        return space.wrap(self.index)
+
+    @jit.unroll_safe
+    def descr_next(self, space):
+        if self.index >= self.size:
+            self.done = True
+            raise OperationError(space.w_StopIteration, space.w_None)
+        self.index += 1
+        res = []
+        for i, (it, st) in enumerate(self.iters):
+            res.append(self._get_item(it, st))
+            self.iters[i] = (it, it.next(st))
+        if len(res) < 2:
+            return res[0]
+        return space.newtuple(res)
+
+    def _get_item(self, it, st):
+        return W_NDimArray(it.getoperand(st))
+
+
+W_Broadcast.typedef = TypeDef("numpy.broadcast",
+                              __new__=interp2app(W_Broadcast.descr_new_broadcast),
+                              __iter__=interp2app(W_Broadcast.descr_iter),
+                              next=interp2app(W_Broadcast.descr_next),
+                              shape=GetSetProperty(W_Broadcast.descr_get_shape),
+                              size=GetSetProperty(W_Broadcast.descr_get_size),
+                              index=GetSetProperty(W_Broadcast.descr_get_index),
+                              )
diff --git a/pypy/module/micronumpy/test/test_broadcast.py b/pypy/module/micronumpy/test/test_broadcast.py
new file mode 100644
--- /dev/null
+++ b/pypy/module/micronumpy/test/test_broadcast.py
@@ -0,0 +1,75 @@
+# -*- encoding: utf-8 -*-
+
+from pypy.module.micronumpy.test.test_base import BaseNumpyAppTest
+
+
+class AppTestArrayBroadcast(BaseNumpyAppTest):
+    def test_broadcast_for_row_and_column(self):
+        import numpy as np
+        x = np.array([[1], [2], [3]])
+        y = np.array([4, 5])
+        b = list(np.broadcast(x, y))
+        assert b == [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)]
+
+    def test_broadcast_properties(self):
+        import numpy as np
+        x = np.array([[1], [2], [3]])
+        y = np.array([4, 5])
+        b = np.broadcast(x, y)
+
+        assert b.shape == (3, 2)
+        assert b.size == 6
+        assert b.index == 0
+
+        b.next()
+        b.next()
+
+        assert b.shape == (3, 2)
+        assert b.size == 6
+        assert b.index == 2
+
+    def test_broadcast_from_doctest(self):
+        """
+        Test from numpy.broadcast doctest.
+        """
+        import numpy as np
+        x = np.array([[1], [2], [3]])
+        y = np.array([4, 5, 6])
+        reference = np.array([[5., 6., 7.],
+                              [6., 7., 8.],
+                              [7., 8., 9.]])
+
+        b = np.broadcast(x, y)
+        out = np.empty(b.shape)
+        out.flat = [u + v for (u, v) in b]
+
+        assert (reference == out).all()
+        assert out.dtype == reference.dtype
+        assert b.shape == reference.shape
+
+    def test_broadcast_linear(self):
+        import numpy as np
+        x = np.array([1, 2, 3])
+        y = np.array([4, 5, 6])
+        b = list(np.broadcast(x, y))
+        assert b == [(1, 4), (2, 5), (3, 6)]
+        assert b[0][0].dtype == x.dtype
+
+    def test_broadcast_linear_unequal(self):
+        import numpy as np
+        x = np.array([1, 2, 3])
+        y = np.array([4, 5])
+        raises(ValueError, np.broadcast, x, y)
+
+    def test_broadcast_3_args(self):
+        import numpy as np
+        x = np.array([[[1]], [[2]], [[3]]])
+        y = np.array([[[40], [50]]])
+        z = np.array([[[700, 800]]])
+
+        b = list(np.broadcast(x, y, z))
+
+        assert b == [(1, 40, 700), (1, 40, 800), (1, 50, 700), (1, 50, 800),
+                     (2, 40, 700), (2, 40, 800), (2, 50, 700), (2, 50, 800),
+                     (3, 40, 700), (3, 40, 800), (3, 50, 700), (3, 50, 800)]
+
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to