Author: mattip
Branch: matrixmath-dot
Changeset: r50197:d28b98fc74ed
Date: 2011-12-05 23:17 +0200
http://bitbucket.org/pypy/pypy/changeset/d28b98fc74ed/
Log: dot works diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -152,7 +152,7 @@ return arr def done(self): - return self.offset == self.size + return self.offset >= self.size def get_offset(self): return self.offset @@ -197,7 +197,7 @@ class BroadcastIterator(BaseIterator): '''Like a view iterator, but will repeatedly access values for all iterations across a res_shape, folding the offset - using mod() arithmetic + using stride = backstride = 0 ''' def __init__(self, arr, res_shape): self.indices = [0] * len(res_shape) @@ -522,7 +522,7 @@ assert other_critical_dim >= 0 out_shape += self.shape[:-1] + \ w_other.shape[0:other_critical_dim] + \ - w_other.shape[other_critical_dim:] + w_other.shape[other_critical_dim + 1:] elif len(w_other.shape) > 0: #dot does not reduce out_shape += self.shape[:-1] @@ -535,25 +535,28 @@ out_ndims = len(out_shape) #TODO: what should the order be? C or F? arr = W_NDimArray(out_size, out_shape, dtype=dtype) - out_iter = ArrayIterator(out_size) + out_iter = ViewIterator(arr) #TODO: invalidate self, w_other with arr ? 
- me_iter = BroadcastIterator(self, self.shape[:-1] + [1]) - assert other_critical_dim >= 0 - other_iter = BroadcastIterator(self, - w_other.shape[:other_critical_dim] + [1] + \ - w_other.shape[other_critical_dim:]) while not out_iter.done(): - w_ssd = space.newlist([space.wrap(me_iter.get_offset()), - space.wrap(len(self.shape)-1)]) - w_osd = space.newlist([space.wrap(other_iter.get_offset()), + my_index = self.start + other_index = w_other.start + i = 0 + while i < len(self.shape) - 1: + my_index += out_iter.indices[i] * self.strides[i] + i += 1 + for j in range(len(w_other.shape) - 2): + other_index += out_iter.indices[i] * w_other.strides[j] + other_index += out_iter.indices[-1] * w_other.strides[-1] + w_ssd = space.newlist([space.wrap(my_index), + space.wrap(len(self.shape) - 1)]) + w_osd = space.newlist([space.wrap(other_index), space.wrap(other_critical_dim)]) w_res = self.descr_mul1d(space, w_other, w_ssd, w_osd) + assert isinstance(w_res, BaseArray) value = w_res.descr_sum(space) - abc=hgk - arr.setitem(out_iter, value) + arr.setitem(out_iter.get_offset(), value) out_iter = out_iter.next(out_ndims) - me_iter = me_iter.next(0) - other_iter = other_iter.next(0) + ii += 1 return arr def get_concrete(self): @@ -818,7 +821,8 @@ shape[:]) def descr_mean(self, space): - return space.div(self.descr_sumpromote(space), space.wrap(self.find_size())) + return space.div(self.descr_sumpromote(space), + space.wrap(self.find_size())) def descr_nonzero(self, space): if self.find_size() > 1: @@ -940,7 +944,7 @@ shapelen=shapelen, result_size=result_size, i=i, ri=ri, self=self, result=result) - result.dtype.setitem(result.storage, ri.offset, self.eval(i)) + result.dtype.setitem(result.storage, ri.get_offset(), self.eval(i)) i = i.next(shapelen) ri = ri.next(shapelen) return result @@ -1045,6 +1049,14 @@ if res_shape is None: res_shape = self.shape # we still force the shape on children #TODO: use left_start_dim, right_start_dim if they are not [-1, -1] + if 
self.left_start_dim[0] >= 0: + ldim = self.left_start_dim[1] + rdim = self.right_start_dim[1] + left_iter = OneDimIterator(self.left_start_dim[0], + self.left.strides[ldim], self.left.shape[ldim]) + right_iter = OneDimIterator(self.right_start_dim[0], + self.right.strides[rdim], self.right.shape[rdim]) + return Call2Iterator(left_iter, right_iter) return Call2Iterator(self.left.start_iter(res_shape), self.right.start_iter(res_shape)) @@ -1143,7 +1155,7 @@ self=self, source=source, res_iter=res_iter, source_iter=source_iter) - self.setitem(res_iter.offset, source.eval(source_iter).convert_to( + self.setitem(res_iter.get_offset(), source.eval(source_iter).convert_to( self.find_dtype())) source_iter = source_iter.next(shapelen) res_iter = res_iter.next(shapelen) @@ -1165,7 +1177,7 @@ array = W_NDimArray(self.size, self.shape[:], self.find_dtype()) iter = self.start_iter() while not iter.done(): - array.setitem(iter.offset, self.getitem(iter.offset)) + array.setitem(iter.get_offset(), self.getitem(iter.get_offset())) iter = iter.next(len(self.shape)) return array @@ -1280,7 +1292,7 @@ arr_iter = arr.start_iter(arr.shape) for i in range(len(elems_w)): w_elem = elems_w[i] - dtype.setitem(arr.storage, arr_iter.offset, dtype.coerce(space, w_elem)) + dtype.setitem(arr.storage, arr_iter.get_offset(), dtype.coerce(space, w_elem)) arr_iter = arr_iter.next(shapelen) return arr diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py --- a/pypy/module/micronumpy/interp_ufuncs.py +++ b/pypy/module/micronumpy/interp_ufuncs.py @@ -140,6 +140,7 @@ def call(self, space, args_w): from pypy.module.micronumpy.interp_numarray import (Call2, convert_to_array, Scalar, shape_agreement) + #TODO: use of w_ssd, w_osd can be optimized. 
if len(args_w)<4: [w_lhs, w_rhs] = args_w w_ssd = space.newlist([space.wrap(-1)]*2) @@ -166,9 +167,17 @@ new_sig = signature.Signature.find_sig([ self.signature, w_lhs.signature, w_rhs.signature ]) - new_shape = shape_agreement(space, w_lhs.shape, w_rhs.shape) + new_shape = [] + ssd = [space.int_w(s) for s in space.listview(w_ssd)] + osd = [space.int_w(s) for s in space.listview(w_osd)] + if ssd[0]<0: + new_shape = shape_agreement(space, w_lhs.shape, w_rhs.shape) + else: + #Assumption (should have been checked in call): + #w_lhs.shape[ssd[1]] == w_rhs.shape[osd[1]] + new_shape = [w_lhs.shape[ssd[1]]] w_res = Call2(new_sig, new_shape, calc_dtype, - res_dtype, w_lhs, w_rhs, w_ssd, w_osd) + res_dtype, w_lhs, w_rhs, ssd, osd) w_lhs.add_invalidates(w_res) w_rhs.add_invalidates(w_res) return w_res diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py --- a/pypy/module/micronumpy/test/test_numarray.py +++ b/pypy/module/micronumpy/test/test_numarray.py @@ -718,6 +718,12 @@ assert a.dot(range(5)) == 30 assert dot(range(5), range(5)) == 30 assert (dot(5, [1, 2, 3]) == [5, 10, 15]).all() + + a = array([range(4), range(4, 8), range(8, 12)]) + b = array([range(3), range(3, 6), range(6, 9), range(9, 12)]) + c = a.dot(b) + assert (c == [[ 42, 48, 54], [114, 136, 158], [186, 224, 262]]).all() + a = array([[range(4), range(4, 8), range(8, 12)], [range(12, 16), range(16, 20), range(20, 24)]]) raises(ValueError, "a.dot(a)") _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit