Author: Maciej Fijalkowski <[email protected]>
Branch: numpy-refactor
Changeset: r57073:5761558dcdcb
Date: 2012-09-01 22:34 +0200
http://bitbucket.org/pypy/pypy/changeset/5761558dcdcb/
Log: Pass enough around to start implementing broadcasting
diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -1,7 +1,7 @@
from pypy.module.micronumpy.arrayimpl import base
from pypy.module.micronumpy import support, loop
-from pypy.module.micronumpy.strides import calc_new_strides
+from pypy.module.micronumpy.strides import calc_new_strides, shape_agreement
from pypy.module.micronumpy.iter import Chunk, Chunks, NewAxisChunk,
RecordChunk
from pypy.interpreter.error import OperationError, operationerrfmt
from pypy.rlib import jit
@@ -111,7 +111,8 @@
def get_shape(self):
return self.shape
- def create_iter(self):
+ def create_iter(self, shape):
+ assert shape == self.shape
return ConcreteArrayIterator(self)
def getitem(self, index):
@@ -125,16 +126,17 @@
def copy(self):
impl = ConcreteArray(self.shape, self.dtype, self.order)
- return loop.setslice(impl, self)
+ return loop.setslice(self.shape, impl, self)
- def setslice(self, arr):
- if arr.is_scalar():
- self.fill(arr.get_scalar_value())
+ def setslice(self, space, arr):
+ impl = arr.implementation
+ if impl.is_scalar():
+ self.fill(impl.get_scalar_value())
return
- assert isinstance(arr, ConcreteArray)
- if arr.storage == self.storage:
- arr = arr.copy()
- loop.setslice(self, arr)
+ shape = shape_agreement(space, self.shape, arr)
+ if impl.storage == self.storage:
+ impl = impl.copy()
+ loop.setslice(shape, self, impl)
def get_size(self):
return self.size // self.dtype.itemtype.get_element_size()
@@ -247,7 +249,7 @@
w_value = support.convert_to_array(space, w_value)
chunks = self._prepare_slice_args(space, w_index)
view = chunks.apply(self)
- view.implementation.setslice(w_value.implementation)
+ view.implementation.setslice(space, w_value)
def transpose(self):
if len(self.shape) < 2:
@@ -279,7 +281,8 @@
def fill(self, box):
loop.fill(self, box)
- def create_iter(self):
+ def create_iter(self, shape):
+ assert shape == self.shape
if len(self.shape) == 1:
return OneDimViewIterator(self)
return MultiDimViewIterator(self)
diff --git a/pypy/module/micronumpy/arrayimpl/scalar.py b/pypy/module/micronumpy/arrayimpl/scalar.py
--- a/pypy/module/micronumpy/arrayimpl/scalar.py
+++ b/pypy/module/micronumpy/arrayimpl/scalar.py
@@ -28,7 +28,7 @@
def get_shape(self):
return []
- def create_iter(self):
+ def create_iter(self, shape):
return ScalarIterator(self.value)
def set_scalar_value(self, value):
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -91,7 +91,7 @@
#return space.call_function(cache.w_array_repr, self)
def dump_data(self):
- i = self.create_iter()
+ i = self.create_iter(self.get_shape())
first = True
dtype = self.get_dtype()
s = StringBuilder()
@@ -106,8 +106,8 @@
s.append('])')
return s.build()
- def create_iter(self):
- return self.implementation.create_iter()
+ def create_iter(self, shape):
+ return self.implementation.create_iter(shape)
def is_scalar(self):
return self.implementation.is_scalar()
@@ -348,7 +348,7 @@
if ndmin > len(shape):
shape = [1] * (ndmin - len(shape)) + shape
arr = W_NDimArray(shape, dtype, order=order)
- arr_iter = arr.create_iter()
+ arr_iter = arr.create_iter(arr.get_shape())
for w_elem in elems_w:
arr_iter.setitem(dtype.coerce(space, w_elem))
arr_iter.next()
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -245,7 +245,7 @@
w_obj.get_dtype(),
promote_to_float=self.promote_to_float,
promote_bools=self.promote_bools)
- if out:
+ if out is not None:
if not isinstance(out, W_NDimArray):
raise OperationError(space.w_TypeError, space.wrap(
'output must be an array'))
@@ -264,20 +264,8 @@
else:
out = arr
return space.wrap(out)
- if not out:
- out = W_NDimArray(w_obj.get_shape(), res_dtype)
- else:
- assert isinstance(out, W_NDimArray) # For translation
- broadcast_shape = shape_agreement(space, w_obj.get_shape(),
- out.get_shape())
- if not broadcast_shape or broadcast_shape != out.get_shape():
- raise operationerrfmt(space.w_ValueError,
- 'output parameter shape mismatch, could not broadcast
[%s]' +
- ' to [%s]',
- ",".join([str(x) for x in w_obj.get_shape()]),
- ",".join([str(x) for x in out.get_shape()]),
- )
- return loop.call1(self.func, self.name, calc_dtype, res_dtype,
+ shape = shape_agreement(space, w_obj.get_shape(), out)
+ return loop.call1(shape, self.func, self.name, calc_dtype, res_dtype,
w_obj, out)
@@ -341,19 +329,9 @@
else:
out = arr
return space.wrap(out)
- new_shape = shape_agreement(space, w_lhs.get_shape(),
- w_rhs.get_shape())
- # Test correctness of out.shape
- if out and out.shape != shape_agreement(space, new_shape, out.shape):
- raise operationerrfmt(space.w_ValueError,
- 'output parameter shape mismatch, could not broadcast [%s]' +
- ' to [%s]',
- ",".join([str(x) for x in new_shape]),
- ",".join([str(x) for x in out.shape]),
- )
- if out is None:
- out = W_NDimArray(new_shape, res_dtype)
- return loop.call2(self.func, self.name, calc_dtype,
+ new_shape = shape_agreement(space, w_lhs.get_shape(), w_rhs)
+ new_shape = shape_agreement(space, new_shape, out)
+ return loop.call2(new_shape, self.func, self.name, calc_dtype,
res_dtype, w_lhs, w_rhs, out)
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -3,10 +3,14 @@
signatures
"""
-def call2(func, name, calc_dtype, res_dtype, w_lhs, w_rhs, out):
- left_iter = w_lhs.create_iter()
- right_iter = w_rhs.create_iter()
- out_iter = out.create_iter()
+from pypy.module.micronumpy.support import create_array
+
+def call2(shape, func, name, calc_dtype, res_dtype, w_lhs, w_rhs, out):
+ if out is None:
+ out = create_array(shape, res_dtype)
+ left_iter = w_lhs.create_iter(shape)
+ right_iter = w_rhs.create_iter(shape)
+ out_iter = out.create_iter(shape)
while not out_iter.done():
w_left = left_iter.getitem().convert_to(calc_dtype)
w_right = right_iter.getitem().convert_to(calc_dtype)
@@ -17,9 +21,11 @@
out_iter.next()
return out
-def call1(func, name , calc_dtype, res_dtype, w_obj, out):
- obj_iter = w_obj.create_iter()
- out_iter = out.create_iter()
+def call1(shape, func, name , calc_dtype, res_dtype, w_obj, out):
+ if out is None:
+ out = create_array(shape, res_dtype)
+ obj_iter = w_obj.create_iter(shape)
+ out_iter = out.create_iter(shape)
while not out_iter.done():
elem = obj_iter.getitem().convert_to(calc_dtype)
out_iter.setitem(func(calc_dtype, elem).convert_to(res_dtype))
@@ -27,10 +33,12 @@
obj_iter.next()
return out
-def setslice(target, source):
- target_iter = target.create_iter()
+def setslice(shape, target, source):
+ # note that unlike everything else, target and source here are
+ # array implementations, not arrays
+ target_iter = target.create_iter(shape)
+ source_iter = source.create_iter(shape)
dtype = target.dtype
- source_iter = source.create_iter()
while not target_iter.done():
target_iter.setitem(source_iter.getitem().convert_to(dtype))
target_iter.next()
@@ -38,7 +46,7 @@
return target
def compute_reduce(obj, calc_dtype, func, done_func, identity):
- obj_iter = obj.create_iter()
+ obj_iter = obj.create_iter(obj.get_shape())
if identity is None:
cur_value = obj_iter.getitem().convert_to(calc_dtype)
obj_iter.next()
@@ -53,7 +61,7 @@
return cur_value
def fill(arr, box):
- arr_iter = arr.create_iter()
+ arr_iter = arr.create_iter(arr.get_shape())
while not arr_iter.done():
arr_iter.setitem(box)
arr_iter.next()
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -104,7 +104,13 @@
i //= shape[s]
return coords, step, lngth
-def shape_agreement(space, shape1, shape2):
+def shape_agreement(space, shape1, w_arr2):
+ from pypy.module.micronumpy.interp_numarray import W_NDimArray
+
+ if w_arr2 is None:
+ return shape1
+ assert isinstance(w_arr2, W_NDimArray)
+ shape2 = w_arr2.get_shape()
ret = _shape_agreement(shape1, shape2)
if len(ret) < max(len(shape1), len(shape2)):
raise OperationError(space.w_ValueError,
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -138,15 +138,13 @@
assert s.backstrides == [-120, 12, 2]
def test_shape_agreement(self):
- from pypy.module.micronumpy.strides import shape_agreement
- assert shape_agreement(self.space, [3], [3]) == [3]
- assert shape_agreement(self.space, [1, 2, 3], [1, 2, 3]) == [1, 2, 3]
- py.test.raises(OperationError, shape_agreement, self.space, [2], [3])
- assert shape_agreement(self.space, [4, 4], []) == [4, 4]
- assert shape_agreement(self.space,
- [8, 1, 6, 1], [7, 1, 5]) == [8, 7, 6, 5]
- assert shape_agreement(self.space,
- [5, 2], [4, 3, 5, 2]) == [4, 3, 5, 2]
+ from pypy.module.micronumpy.strides import _shape_agreement
+ assert _shape_agreement([3], [3]) == [3]
+ assert _shape_agreement([1, 2, 3], [1, 2, 3]) == [1, 2, 3]
+ _shape_agreement([2], [3]) == 0
+ assert _shape_agreement([4, 4], []) == [4, 4]
+ assert _shape_agreement([8, 1, 6, 1], [7, 1, 5]) == [8, 7, 6, 5]
+ assert _shape_agreement([5, 2], [4, 3, 5, 2]) == [4, 3, 5, 2]
def test_calc_new_strides(self):
from pypy.module.micronumpy.strides import calc_new_strides
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit