Author: mattip
Branch: matrixmath-dot
Changeset: r50197:d28b98fc74ed
Date: 2011-12-05 23:17 +0200
http://bitbucket.org/pypy/pypy/changeset/d28b98fc74ed/

Log:    dot works

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -152,7 +152,7 @@
         return arr
 
     def done(self):
-        return self.offset == self.size
+        return self.offset >= self.size
 
     def get_offset(self):
         return self.offset
@@ -197,7 +197,7 @@
 class BroadcastIterator(BaseIterator):
     '''Like a view iterator, but will repeatedly access values
        for all iterations across a res_shape, folding the offset
-       using mod() arithmetic
+       using stride = backstride = 0
     '''
     def __init__(self, arr, res_shape):
         self.indices = [0] * len(res_shape)
@@ -522,7 +522,7 @@
             assert other_critical_dim >= 0
             out_shape += self.shape[:-1] + \
                          w_other.shape[0:other_critical_dim] + \
-                         w_other.shape[other_critical_dim:]
+                         w_other.shape[other_critical_dim + 1:]
         elif len(w_other.shape) > 0:
             #dot does not reduce
             out_shape += self.shape[:-1]
@@ -535,25 +535,28 @@
         out_ndims = len(out_shape)
         #TODO: what should the order be? C or F?
         arr = W_NDimArray(out_size, out_shape, dtype=dtype)
-        out_iter = ArrayIterator(out_size)
+        out_iter = ViewIterator(arr)
         #TODO: invalidate self, w_other with arr ?
-        me_iter = BroadcastIterator(self, self.shape[:-1] + [1])
-        assert other_critical_dim >= 0
-        other_iter = BroadcastIterator(self,
-                               w_other.shape[:other_critical_dim] + [1] + \
-                                           w_other.shape[other_critical_dim:])
         while not out_iter.done():
-            w_ssd = space.newlist([space.wrap(me_iter.get_offset()), 
-                                   space.wrap(len(self.shape)-1)])
-            w_osd = space.newlist([space.wrap(other_iter.get_offset()), 
+            my_index = self.start
+            other_index = w_other.start
+            i = 0
+            while i < len(self.shape) - 1:
+                my_index += out_iter.indices[i] * self.strides[i]
+                i += 1
+            for j in range(len(w_other.shape) - 2):
+                other_index += out_iter.indices[i] * w_other.strides[j]
+            other_index += out_iter.indices[-1] * w_other.strides[-1]
+            w_ssd = space.newlist([space.wrap(my_index),
+                                   space.wrap(len(self.shape) - 1)])
+            w_osd = space.newlist([space.wrap(other_index),
                                    space.wrap(other_critical_dim)])
             w_res = self.descr_mul1d(space, w_other, w_ssd, w_osd)
+            assert isinstance(w_res, BaseArray)
             value = w_res.descr_sum(space)
-            abc=hgk
-            arr.setitem(out_iter, value)
+            arr.setitem(out_iter.get_offset(), value)
             out_iter = out_iter.next(out_ndims)
-            me_iter = me_iter.next(0)
-            other_iter = other_iter.next(0)
+            ii += 1
         return arr
 
     def get_concrete(self):
@@ -818,7 +821,8 @@
                          shape[:])
 
     def descr_mean(self, space):
-        return space.div(self.descr_sumpromote(space), 
space.wrap(self.find_size()))
+        return space.div(self.descr_sumpromote(space), 
+                         space.wrap(self.find_size()))
 
     def descr_nonzero(self, space):
         if self.find_size() > 1:
@@ -940,7 +944,7 @@
                                          shapelen=shapelen,
                                          result_size=result_size, i=i, ri=ri,
                                          self=self, result=result)
-            result.dtype.setitem(result.storage, ri.offset, self.eval(i))
+            result.dtype.setitem(result.storage, ri.get_offset(), self.eval(i))
             i = i.next(shapelen)
             ri = ri.next(shapelen)
         return result
@@ -1045,6 +1049,14 @@
         if res_shape is None:
             res_shape = self.shape  # we still force the shape on children
         #TODO: use left_start_dim, right_start_dim if they are not [-1, -1]
+        if self.left_start_dim[0] >= 0:
+            ldim = self.left_start_dim[1]
+            rdim = self.right_start_dim[1]
+            left_iter = OneDimIterator(self.left_start_dim[0],
+                        self.left.strides[ldim], self.left.shape[ldim])
+            right_iter = OneDimIterator(self.right_start_dim[0],
+                        self.right.strides[rdim], self.right.shape[rdim])
+            return Call2Iterator(left_iter, right_iter)
         return Call2Iterator(self.left.start_iter(res_shape),
                              self.right.start_iter(res_shape))
 
@@ -1143,7 +1155,7 @@
                                          self=self, source=source,
                                          res_iter=res_iter,
                                          source_iter=source_iter)
-            self.setitem(res_iter.offset, source.eval(source_iter).convert_to(
+            self.setitem(res_iter.get_offset(), 
source.eval(source_iter).convert_to(
                 self.find_dtype()))
             source_iter = source_iter.next(shapelen)
             res_iter = res_iter.next(shapelen)
@@ -1165,7 +1177,7 @@
         array = W_NDimArray(self.size, self.shape[:], self.find_dtype())
         iter = self.start_iter()
         while not iter.done():
-            array.setitem(iter.offset, self.getitem(iter.offset))
+            array.setitem(iter.get_offset(), self.getitem(iter.get_offset()))
             iter = iter.next(len(self.shape))
         return array
 
@@ -1280,7 +1292,7 @@
     arr_iter = arr.start_iter(arr.shape)
     for i in range(len(elems_w)):
         w_elem = elems_w[i]
-        dtype.setitem(arr.storage, arr_iter.offset, dtype.coerce(space, 
w_elem))
+        dtype.setitem(arr.storage, arr_iter.get_offset(), dtype.coerce(space, 
w_elem))
         arr_iter = arr_iter.next(shapelen)
     return arr
 
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -140,6 +140,7 @@
     def call(self, space, args_w):
         from pypy.module.micronumpy.interp_numarray import (Call2,
             convert_to_array, Scalar, shape_agreement)
+        #TODO: use of w_ssd, w_osd can be optimized.
         if len(args_w)<4:
             [w_lhs, w_rhs] = args_w
             w_ssd = space.newlist([space.wrap(-1)]*2)
@@ -166,9 +167,17 @@
         new_sig = signature.Signature.find_sig([
             self.signature, w_lhs.signature, w_rhs.signature
         ])
-        new_shape = shape_agreement(space, w_lhs.shape, w_rhs.shape)
+        new_shape = []
+        ssd = [space.int_w(s) for s in space.listview(w_ssd)]
+        osd = [space.int_w(s) for s in space.listview(w_osd)]
+        if  ssd[0]<0:
+            new_shape = shape_agreement(space, w_lhs.shape, w_rhs.shape)
+        else:
+            #Assumption (should have been checked in call): 
+            #w_lhs.shape[ssd[1]] == w_rhs.shape[osd[1]]
+            new_shape = [w_lhs.shape[ssd[1]]]
         w_res = Call2(new_sig, new_shape, calc_dtype,
-                      res_dtype, w_lhs, w_rhs, w_ssd, w_osd)
+                      res_dtype, w_lhs, w_rhs, ssd, osd)
         w_lhs.add_invalidates(w_res)
         w_rhs.add_invalidates(w_res)
         return w_res
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -718,6 +718,12 @@
         assert a.dot(range(5)) == 30
         assert dot(range(5), range(5)) == 30
         assert (dot(5, [1, 2, 3]) == [5, 10, 15]).all()
+
+        a = array([range(4), range(4, 8), range(8, 12)])
+        b = array([range(3), range(3, 6), range(6, 9), range(9, 12)])
+        c = a.dot(b)
+        assert (c == [[ 42, 48, 54], [114, 136, 158], [186, 224, 262]]).all()
+
         a = array([[range(4), range(4, 8), range(8, 12)],
                    [range(12, 16), range(16, 20), range(20, 24)]])
         raises(ValueError, "a.dot(a)")
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to