Author: Maciej Fijalkowski <[email protected]>
Branch: numpy-refactor
Changeset: r57073:5761558dcdcb
Date: 2012-09-01 22:34 +0200
http://bitbucket.org/pypy/pypy/changeset/5761558dcdcb/

Log:    Pass enough around to start implementing broadcasting

diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -1,7 +1,7 @@
 
 from pypy.module.micronumpy.arrayimpl import base
 from pypy.module.micronumpy import support, loop
-from pypy.module.micronumpy.strides import calc_new_strides
+from pypy.module.micronumpy.strides import calc_new_strides, shape_agreement
 from pypy.module.micronumpy.iter import Chunk, Chunks, NewAxisChunk, RecordChunk
 from pypy.interpreter.error import OperationError, operationerrfmt
 from pypy.rlib import jit
@@ -111,7 +111,8 @@
     def get_shape(self):
         return self.shape
 
-    def create_iter(self):
+    def create_iter(self, shape):
+        assert shape == self.shape
         return ConcreteArrayIterator(self)
 
     def getitem(self, index):
@@ -125,16 +126,17 @@
 
     def copy(self):
         impl = ConcreteArray(self.shape, self.dtype, self.order)
-        return loop.setslice(impl, self)
+        return loop.setslice(self.shape, impl, self)
 
-    def setslice(self, arr):
-        if arr.is_scalar():
-            self.fill(arr.get_scalar_value())
+    def setslice(self, space, arr):
+        impl = arr.implementation
+        if impl.is_scalar():
+            self.fill(impl.get_scalar_value())
             return
-        assert isinstance(arr, ConcreteArray)
-        if arr.storage == self.storage:
-            arr = arr.copy()
-        loop.setslice(self, arr)
+        shape = shape_agreement(space, self.shape, arr)
+        if impl.storage == self.storage:
+            impl = impl.copy()
+        loop.setslice(shape, self, impl)
 
     def get_size(self):
         return self.size // self.dtype.itemtype.get_element_size()
@@ -247,7 +249,7 @@
             w_value = support.convert_to_array(space, w_value)
             chunks = self._prepare_slice_args(space, w_index)
             view = chunks.apply(self)
-            view.implementation.setslice(w_value.implementation)
+            view.implementation.setslice(space, w_value)
 
     def transpose(self):
         if len(self.shape) < 2:
@@ -279,7 +281,8 @@
     def fill(self, box):
         loop.fill(self, box)
 
-    def create_iter(self):
+    def create_iter(self, shape):
+        assert shape == self.shape
         if len(self.shape) == 1:
             return OneDimViewIterator(self)
         return MultiDimViewIterator(self)
diff --git a/pypy/module/micronumpy/arrayimpl/scalar.py b/pypy/module/micronumpy/arrayimpl/scalar.py
--- a/pypy/module/micronumpy/arrayimpl/scalar.py
+++ b/pypy/module/micronumpy/arrayimpl/scalar.py
@@ -28,7 +28,7 @@
     def get_shape(self):
         return []
 
-    def create_iter(self):
+    def create_iter(self, shape):
         return ScalarIterator(self.value)
 
     def set_scalar_value(self, value):
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -91,7 +91,7 @@
         #return space.call_function(cache.w_array_repr, self)
 
     def dump_data(self):
-        i = self.create_iter()
+        i = self.create_iter(self.get_shape())
         first = True
         dtype = self.get_dtype()
         s = StringBuilder()
@@ -106,8 +106,8 @@
         s.append('])')
         return s.build()
 
-    def create_iter(self):
-        return self.implementation.create_iter()
+    def create_iter(self, shape):
+        return self.implementation.create_iter(shape)
 
     def is_scalar(self):
         return self.implementation.is_scalar()
@@ -348,7 +348,7 @@
     if ndmin > len(shape):
         shape = [1] * (ndmin - len(shape)) + shape
     arr = W_NDimArray(shape, dtype, order=order)
-    arr_iter = arr.create_iter()
+    arr_iter = arr.create_iter(arr.get_shape())
     for w_elem in elems_w:
         arr_iter.setitem(dtype.coerce(space, w_elem))
         arr_iter.next()
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -245,7 +245,7 @@
                                   w_obj.get_dtype(),
                                   promote_to_float=self.promote_to_float,
                                   promote_bools=self.promote_bools)
-        if out:
+        if out is not None:
             if not isinstance(out, W_NDimArray):
                 raise OperationError(space.w_TypeError, space.wrap(
                                                 'output must be an array'))
@@ -264,20 +264,8 @@
             else:
                 out = arr
             return space.wrap(out)
-        if not out:
-            out = W_NDimArray(w_obj.get_shape(), res_dtype)
-        else:
-            assert isinstance(out, W_NDimArray) # For translation
-            broadcast_shape =  shape_agreement(space, w_obj.get_shape(),
-                                               out.get_shape())
-            if not broadcast_shape or broadcast_shape != out.get_shape():
-                raise operationerrfmt(space.w_ValueError,
-                    'output parameter shape mismatch, could not broadcast [%s]' +
-                    ' to [%s]',
-                    ",".join([str(x) for x in w_obj.get_shape()]),
-                    ",".join([str(x) for x in out.get_shape()]),
-                    )
-        return loop.call1(self.func, self.name, calc_dtype, res_dtype,
+        shape =  shape_agreement(space, w_obj.get_shape(), out)
+        return loop.call1(shape, self.func, self.name, calc_dtype, res_dtype,
                           w_obj, out)
 
 
@@ -341,19 +329,9 @@
             else:
                 out = arr
             return space.wrap(out)
-        new_shape = shape_agreement(space, w_lhs.get_shape(),
-                                    w_rhs.get_shape())
-        # Test correctness of out.shape
-        if out and out.shape != shape_agreement(space, new_shape, out.shape):
-            raise operationerrfmt(space.w_ValueError,
-                'output parameter shape mismatch, could not broadcast [%s]' +
-                ' to [%s]',
-                ",".join([str(x) for x in new_shape]),
-                ",".join([str(x) for x in out.shape]),
-                )
-        if out is None:
-            out = W_NDimArray(new_shape, res_dtype)
-        return loop.call2(self.func, self.name, calc_dtype,
+        new_shape = shape_agreement(space, w_lhs.get_shape(), w_rhs)
+        new_shape = shape_agreement(space, new_shape, out)
+        return loop.call2(new_shape, self.func, self.name, calc_dtype,
                           res_dtype, w_lhs, w_rhs, out)
 
 
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -3,10 +3,14 @@
 signatures
 """
 
-def call2(func, name, calc_dtype, res_dtype, w_lhs, w_rhs, out):
-    left_iter = w_lhs.create_iter()
-    right_iter = w_rhs.create_iter()
-    out_iter = out.create_iter()
+from pypy.module.micronumpy.support import create_array
+
+def call2(shape, func, name, calc_dtype, res_dtype, w_lhs, w_rhs, out):
+    if out is None:
+        out = create_array(shape, res_dtype)
+    left_iter = w_lhs.create_iter(shape)
+    right_iter = w_rhs.create_iter(shape)
+    out_iter = out.create_iter(shape)
     while not out_iter.done():
         w_left = left_iter.getitem().convert_to(calc_dtype)
         w_right = right_iter.getitem().convert_to(calc_dtype)
@@ -17,9 +21,11 @@
         out_iter.next()
     return out
 
-def call1(func, name , calc_dtype, res_dtype, w_obj, out):
-    obj_iter = w_obj.create_iter()
-    out_iter = out.create_iter()
+def call1(shape, func, name , calc_dtype, res_dtype, w_obj, out):
+    if out is None:
+        out = create_array(shape, res_dtype)
+    obj_iter = w_obj.create_iter(shape)
+    out_iter = out.create_iter(shape)
     while not out_iter.done():
         elem = obj_iter.getitem().convert_to(calc_dtype)
         out_iter.setitem(func(calc_dtype, elem).convert_to(res_dtype))
@@ -27,10 +33,12 @@
         obj_iter.next()
     return out
 
-def setslice(target, source):
-    target_iter = target.create_iter()
+def setslice(shape, target, source):
+    # note that unlike everything else, target and source here are
+    # array implementations, not arrays
+    target_iter = target.create_iter(shape)
+    source_iter = source.create_iter(shape)
     dtype = target.dtype
-    source_iter = source.create_iter()
     while not target_iter.done():
         target_iter.setitem(source_iter.getitem().convert_to(dtype))
         target_iter.next()
@@ -38,7 +46,7 @@
     return target
 
 def compute_reduce(obj, calc_dtype, func, done_func, identity):
-    obj_iter = obj.create_iter()
+    obj_iter = obj.create_iter(obj.get_shape())
     if identity is None:
         cur_value = obj_iter.getitem().convert_to(calc_dtype)
         obj_iter.next()
@@ -53,7 +61,7 @@
     return cur_value
 
 def fill(arr, box):
-    arr_iter = arr.create_iter()
+    arr_iter = arr.create_iter(arr.get_shape())
     while not arr_iter.done():
         arr_iter.setitem(box)
         arr_iter.next()
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -104,7 +104,13 @@
             i //= shape[s]
     return coords, step, lngth
 
-def shape_agreement(space, shape1, shape2):
+def shape_agreement(space, shape1, w_arr2):
+    from pypy.module.micronumpy.interp_numarray import W_NDimArray
+
+    if w_arr2 is None:
+        return shape1
+    assert isinstance(w_arr2, W_NDimArray)
+    shape2 = w_arr2.get_shape()
     ret = _shape_agreement(shape1, shape2)
     if len(ret) < max(len(shape1), len(shape2)):
         raise OperationError(space.w_ValueError,
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -138,15 +138,13 @@
         assert s.backstrides == [-120, 12, 2]
 
     def test_shape_agreement(self):
-        from pypy.module.micronumpy.strides import shape_agreement
-        assert shape_agreement(self.space, [3], [3]) == [3]
-        assert shape_agreement(self.space, [1, 2, 3], [1, 2, 3]) == [1, 2, 3]
-        py.test.raises(OperationError, shape_agreement, self.space, [2], [3])
-        assert shape_agreement(self.space, [4, 4], []) == [4, 4]
-        assert shape_agreement(self.space,
-                [8, 1, 6, 1], [7, 1, 5]) == [8, 7, 6, 5]
-        assert shape_agreement(self.space,
-                [5, 2], [4, 3, 5, 2]) == [4, 3, 5, 2]
+        from pypy.module.micronumpy.strides import _shape_agreement
+        assert _shape_agreement([3], [3]) == [3]
+        assert _shape_agreement([1, 2, 3], [1, 2, 3]) == [1, 2, 3]
+        assert _shape_agreement([2], [3]) == []  # mismatch yields an empty shape
+        assert _shape_agreement([4, 4], []) == [4, 4]
+        assert _shape_agreement([8, 1, 6, 1], [7, 1, 5]) == [8, 7, 6, 5]
+        assert _shape_agreement([5, 2], [4, 3, 5, 2]) == [4, 3, 5, 2]
 
     def test_calc_new_strides(self):
         from pypy.module.micronumpy.strides import calc_new_strides
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to