Author: Maciej Fijalkowski <[email protected]>
Branch: numpy-refactor
Changeset: r57074:8aa0985313cd
Date: 2012-09-01 23:11 +0200
http://bitbucket.org/pypy/pypy/changeset/8aa0985313cd/
Log: implement broadcasting
diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -1,7 +1,8 @@
from pypy.module.micronumpy.arrayimpl import base
from pypy.module.micronumpy import support, loop
-from pypy.module.micronumpy.strides import calc_new_strides, shape_agreement
+from pypy.module.micronumpy.strides import calc_new_strides, shape_agreement,\
+ calculate_broadcast_strides
from pypy.module.micronumpy.iter import Chunk, Chunks, NewAxisChunk,
RecordChunk
from pypy.interpreter.error import OperationError, operationerrfmt
from pypy.rlib import jit
@@ -43,15 +44,15 @@
return self.index >= self.size
class MultiDimViewIterator(ConcreteArrayIterator):
- def __init__(self, array):
- self.indexes = [0] * len(array.shape)
+ def __init__(self, array, start, strides, backstrides, shape):
+ self.indexes = [0] * len(shape)
self.array = array
- self.shape = array.shape
- self.offset = array.start
- self.shapelen = len(self.shape)
+ self.shape = shape
+ self.offset = start
+ self.shapelen = len(shape)
self._done = False
- self.strides = array.strides
- self.backstrides = array.backstrides
+ self.strides = strides
+ self.backstrides = backstrides
@jit.unroll_safe
def next(self):
@@ -112,8 +113,11 @@
return self.shape
def create_iter(self, shape):
- assert shape == self.shape
- return ConcreteArrayIterator(self)
+ if shape == self.shape:
+ return ConcreteArrayIterator(self)
+ r = calculate_broadcast_strides(self.strides, self.backstrides,
+ self.shape, shape)
+ return MultiDimViewIterator(self, 0, r[0], r[1], shape)
def getitem(self, index):
return self.dtype.getitem(self, index)
@@ -282,7 +286,12 @@
loop.fill(self, box)
def create_iter(self, shape):
- assert shape == self.shape
+ if shape != self.shape:
+ r = calculate_broadcast_strides(self.strides, self.backstrides,
+ self.shape, shape)
+ return MultiDimViewIterator(self.parent,
+ self.start, r[0], r[1], shape)
if len(self.shape) == 1:
return OneDimViewIterator(self)
- return MultiDimViewIterator(self)
+ return MultiDimViewIterator(self.parent, self.start, self.strides,
+ self.backstrides, self.shape)
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit