Author: Maciej Fijalkowski <[email protected]>
Branch: numpy-refactor
Changeset: r57074:8aa0985313cd
Date: 2012-09-01 23:11 +0200
http://bitbucket.org/pypy/pypy/changeset/8aa0985313cd/

Log:    implement broadcasting

diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -1,7 +1,8 @@
 
 from pypy.module.micronumpy.arrayimpl import base
 from pypy.module.micronumpy import support, loop
-from pypy.module.micronumpy.strides import calc_new_strides, shape_agreement
+from pypy.module.micronumpy.strides import calc_new_strides, shape_agreement,\
+     calculate_broadcast_strides
 from pypy.module.micronumpy.iter import Chunk, Chunks, NewAxisChunk, RecordChunk
 from pypy.interpreter.error import OperationError, operationerrfmt
 from pypy.rlib import jit
@@ -43,15 +44,15 @@
         return self.index >= self.size
 
 class MultiDimViewIterator(ConcreteArrayIterator):
-    def __init__(self, array):
-        self.indexes = [0] * len(array.shape)
+    def __init__(self, array, start, strides, backstrides, shape):
+        self.indexes = [0] * len(shape)
         self.array = array
-        self.shape = array.shape
-        self.offset = array.start
-        self.shapelen = len(self.shape)
+        self.shape = shape
+        self.offset = start
+        self.shapelen = len(shape)
         self._done = False
-        self.strides = array.strides
-        self.backstrides = array.backstrides
+        self.strides = strides
+        self.backstrides = backstrides
 
     @jit.unroll_safe
     def next(self):
@@ -112,8 +113,11 @@
         return self.shape
 
     def create_iter(self, shape):
-        assert shape == self.shape
-        return ConcreteArrayIterator(self)
+        if shape == self.shape:
+            return ConcreteArrayIterator(self)
+        r = calculate_broadcast_strides(self.strides, self.backstrides,
+                                        self.shape, shape)
+        return MultiDimViewIterator(self, 0, r[0], r[1], shape)
 
     def getitem(self, index):
         return self.dtype.getitem(self, index)
@@ -282,7 +286,12 @@
         loop.fill(self, box)
 
     def create_iter(self, shape):
-        assert shape == self.shape
+        if shape != self.shape:
+            r = calculate_broadcast_strides(self.strides, self.backstrides,
+                                            self.shape, shape)
+            return MultiDimViewIterator(self.parent,
+                                        self.start, r[0], r[1], shape)
         if len(self.shape) == 1:
             return OneDimViewIterator(self)
-        return MultiDimViewIterator(self)
+        return MultiDimViewIterator(self.parent, self.start, self.strides,
+                                    self.backstrides, self.shape)
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to