Author: mattip
Branch: matrixmath-dot
Changeset: r50197:d28b98fc74ed
Date: 2011-12-05 23:17 +0200
http://bitbucket.org/pypy/pypy/changeset/d28b98fc74ed/
Log: dot works
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -152,7 +152,7 @@
return arr
def done(self):
- return self.offset == self.size
+ return self.offset >= self.size
def get_offset(self):
return self.offset
@@ -197,7 +197,7 @@
class BroadcastIterator(BaseIterator):
'''Like a view iterator, but will repeatedly access values
for all iterations across a res_shape, folding the offset
- using mod() arithmetic
+ using stride = backstride = 0
'''
def __init__(self, arr, res_shape):
self.indices = [0] * len(res_shape)
@@ -522,7 +522,7 @@
assert other_critical_dim >= 0
out_shape += self.shape[:-1] + \
w_other.shape[0:other_critical_dim] + \
- w_other.shape[other_critical_dim:]
+ w_other.shape[other_critical_dim + 1:]
elif len(w_other.shape) > 0:
#dot does not reduce
out_shape += self.shape[:-1]
@@ -535,25 +535,28 @@
out_ndims = len(out_shape)
#TODO: what should the order be? C or F?
arr = W_NDimArray(out_size, out_shape, dtype=dtype)
- out_iter = ArrayIterator(out_size)
+ out_iter = ViewIterator(arr)
#TODO: invalidate self, w_other with arr ?
- me_iter = BroadcastIterator(self, self.shape[:-1] + [1])
- assert other_critical_dim >= 0
- other_iter = BroadcastIterator(self,
- w_other.shape[:other_critical_dim] + [1] + \
- w_other.shape[other_critical_dim:])
while not out_iter.done():
- w_ssd = space.newlist([space.wrap(me_iter.get_offset()),
- space.wrap(len(self.shape)-1)])
- w_osd = space.newlist([space.wrap(other_iter.get_offset()),
+ my_index = self.start
+ other_index = w_other.start
+ i = 0
+ while i < len(self.shape) - 1:
+ my_index += out_iter.indices[i] * self.strides[i]
+ i += 1
+ for j in range(len(w_other.shape) - 2):
+ other_index += out_iter.indices[i] * w_other.strides[j]
+ other_index += out_iter.indices[-1] * w_other.strides[-1]
+ w_ssd = space.newlist([space.wrap(my_index),
+ space.wrap(len(self.shape) - 1)])
+ w_osd = space.newlist([space.wrap(other_index),
space.wrap(other_critical_dim)])
w_res = self.descr_mul1d(space, w_other, w_ssd, w_osd)
+ assert isinstance(w_res, BaseArray)
value = w_res.descr_sum(space)
- abc=hgk
- arr.setitem(out_iter, value)
+ arr.setitem(out_iter.get_offset(), value)
out_iter = out_iter.next(out_ndims)
- me_iter = me_iter.next(0)
- other_iter = other_iter.next(0)
+ ii += 1
return arr
def get_concrete(self):
@@ -818,7 +821,8 @@
shape[:])
def descr_mean(self, space):
- return space.div(self.descr_sumpromote(space),
space.wrap(self.find_size()))
+ return space.div(self.descr_sumpromote(space),
+ space.wrap(self.find_size()))
def descr_nonzero(self, space):
if self.find_size() > 1:
@@ -940,7 +944,7 @@
shapelen=shapelen,
result_size=result_size, i=i, ri=ri,
self=self, result=result)
- result.dtype.setitem(result.storage, ri.offset, self.eval(i))
+ result.dtype.setitem(result.storage, ri.get_offset(), self.eval(i))
i = i.next(shapelen)
ri = ri.next(shapelen)
return result
@@ -1045,6 +1049,14 @@
if res_shape is None:
res_shape = self.shape # we still force the shape on children
#TODO: use left_start_dim, right_start_dim if they are not [-1, -1]
+ if self.left_start_dim[0] >= 0:
+ ldim = self.left_start_dim[1]
+ rdim = self.right_start_dim[1]
+ left_iter = OneDimIterator(self.left_start_dim[0],
+ self.left.strides[ldim], self.left.shape[ldim])
+ right_iter = OneDimIterator(self.right_start_dim[0],
+ self.right.strides[rdim], self.right.shape[rdim])
+ return Call2Iterator(left_iter, right_iter)
return Call2Iterator(self.left.start_iter(res_shape),
self.right.start_iter(res_shape))
@@ -1143,7 +1155,7 @@
self=self, source=source,
res_iter=res_iter,
source_iter=source_iter)
- self.setitem(res_iter.offset, source.eval(source_iter).convert_to(
+ self.setitem(res_iter.get_offset(),
source.eval(source_iter).convert_to(
self.find_dtype()))
source_iter = source_iter.next(shapelen)
res_iter = res_iter.next(shapelen)
@@ -1165,7 +1177,7 @@
array = W_NDimArray(self.size, self.shape[:], self.find_dtype())
iter = self.start_iter()
while not iter.done():
- array.setitem(iter.offset, self.getitem(iter.offset))
+ array.setitem(iter.get_offset(), self.getitem(iter.get_offset()))
iter = iter.next(len(self.shape))
return array
@@ -1280,7 +1292,7 @@
arr_iter = arr.start_iter(arr.shape)
for i in range(len(elems_w)):
w_elem = elems_w[i]
- dtype.setitem(arr.storage, arr_iter.offset, dtype.coerce(space,
w_elem))
+ dtype.setitem(arr.storage, arr_iter.get_offset(), dtype.coerce(space,
w_elem))
arr_iter = arr_iter.next(shapelen)
return arr
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -140,6 +140,7 @@
def call(self, space, args_w):
from pypy.module.micronumpy.interp_numarray import (Call2,
convert_to_array, Scalar, shape_agreement)
+ #TODO: use of w_ssd, w_osd can be optimized.
if len(args_w)<4:
[w_lhs, w_rhs] = args_w
w_ssd = space.newlist([space.wrap(-1)]*2)
@@ -166,9 +167,17 @@
new_sig = signature.Signature.find_sig([
self.signature, w_lhs.signature, w_rhs.signature
])
- new_shape = shape_agreement(space, w_lhs.shape, w_rhs.shape)
+ new_shape = []
+ ssd = [space.int_w(s) for s in space.listview(w_ssd)]
+ osd = [space.int_w(s) for s in space.listview(w_osd)]
+ if ssd[0]<0:
+ new_shape = shape_agreement(space, w_lhs.shape, w_rhs.shape)
+ else:
+ #Assumption (should have been checked in call):
+ #w_lhs.shape[ssd[1]] == w_rhs.shape[osd[1]]
+ new_shape = [w_lhs.shape[ssd[1]]]
w_res = Call2(new_sig, new_shape, calc_dtype,
- res_dtype, w_lhs, w_rhs, w_ssd, w_osd)
+ res_dtype, w_lhs, w_rhs, ssd, osd)
w_lhs.add_invalidates(w_res)
w_rhs.add_invalidates(w_res)
return w_res
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -718,6 +718,12 @@
assert a.dot(range(5)) == 30
assert dot(range(5), range(5)) == 30
assert (dot(5, [1, 2, 3]) == [5, 10, 15]).all()
+
+ a = array([range(4), range(4, 8), range(8, 12)])
+ b = array([range(3), range(3, 6), range(6, 9), range(9, 12)])
+ c = a.dot(b)
+ assert (c == [[ 42, 48, 54], [114, 136, 158], [186, 224, 262]]).all()
+
a = array([[range(4), range(4, 8), range(8, 12)],
[range(12, 16), range(16, 20), range(20, 24)]])
raises(ValueError, "a.dot(a)")
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit