Author: mattip
Branch: matrixmath-dot
Changeset: r50193:beba9400c9dd
Date: 2011-12-05 15:41 +0200
http://bitbucket.org/pypy/pypy/changeset/beba9400c9dd/

Log:    add bin_impl_one_dim, would be nice to have some tests

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -403,6 +403,7 @@
             greens=['shapelen', 'signature'],
             reds=['result', 'idx', 'i', 'self', 'cur_best', 'dtype']
         )
+
         def loop(self):
             i = self.start_iter()
             cur_best = self.eval(i)
@@ -424,6 +425,7 @@
                 i = i.next(shapelen)
                 idx += 1
             return result
+
         def impl(self, space):
             size = self.find_size()
             if size == 0:
@@ -444,6 +446,7 @@
                 return False
             i = i.next(shapelen)
         return True
+
     def descr_all(self, space):
         return space.wrap(self._all())
 
@@ -459,22 +462,39 @@
                 return True
             i = i.next(shapelen)
         return False
+
     def descr_any(self, space):
         return space.wrap(self._any())
 
     descr_argmax = _reduce_argmax_argmin_impl("max")
     descr_argmin = _reduce_argmax_argmin_impl("min")
 
+    def _binop_impl_one_dim(ufunc_name):
+        #The third and fourth arguments allow the operator to proceed on a
+        #single dimension starting at a particular index
+        #i.e. ssd => self start, dimension; osd => other start, dimension
+        def impl(self, space, w_other, w_ssd, w_osd):
+            return getattr(interp_ufuncs.get(space), ufunc_name).call(space,
+                                         [self, w_other, w_ssd, w_osd])
+        return func_with_new_name(impl, "binop_%s_impl" % ufunc_name)
+
+    descr_add1d = _binop_impl_one_dim("add")
+    descr_sub1d = _binop_impl_one_dim("subtract")
+    descr_mul1d = _binop_impl_one_dim("multiply")
+    descr_div1d = _binop_impl_one_dim("divide")
+    descr_pow1d = _binop_impl_one_dim("power")
+    descr_mod1d = _binop_impl_one_dim("mod")
+
     def descr_dot(self, space, w_other):
         '''Dot product of two arrays.
-    
+
     For 2-D arrays it is equivalent to matrix multiplication, and for 1-D
     arrays to inner product of vectors (without complex conjugation). For
     N dimensions it is a sum product over the last axis of `a` and
     the second-to-last of `b`::
-    
+
         dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])'''
-        #numpy's doc string :) 
+        #numpy's doc string :)
         w_other = convert_to_array(space, w_other)
         if isinstance(w_other, Scalar):
             return self.descr_mul(space, w_other)
@@ -482,23 +502,25 @@
             w_res = self.descr_mul(space, w_other)
             assert isinstance(w_res, BaseArray)
             return w_res.descr_sum(space)
-        dtype = interp_ufuncs.find_binop_result_dtype(space, 
+        dtype = interp_ufuncs.find_binop_result_dtype(space,
                                      self.find_dtype(), w_other.find_dtype())
-        if self.find_size() < 1 and w_other.find_size() <1:
+        if self.find_size() < 1 and w_other.find_size() < 1:
             #numpy compatability
-            return scalar_w(space, dtype,space.wrap(0))
+            return scalar_w(space, dtype, space.wrap(0))
         #Do the dims match?
         my_critical_dim_size = self.shape[-1]
         other_critical_dim_size = w_other.shape[0]
         other_critical_dim = 0
-        other_critical_dim_stride = w_other.strides[0] 
+        other_critical_dim_stride = w_other.strides[0]
         out_shape = []
         if len(w_other.shape) > 1:
-            other_critical_dim = len(w_other.shape)-1
+            other_critical_dim = len(w_other.shape) - 1
             other_critical_dim_size = w_other.shape[other_critical_dim]
             other_critical_dim_stride = w_other.strides[other_critical_dim]
             assert other_critical_dim >= 0
-            out_shape += self.shape[:-1] + w_other.shape[0:other_critical_dim] + w_other.shape[other_critical_dim:]
+            out_shape += self.shape[:-1] + \
+                         w_other.shape[0:other_critical_dim] + \
+                         w_other.shape[other_critical_dim:]
         elif len(w_other.shape) > 0:
             #dot does not reduce
             out_shape += self.shape[:-1]
@@ -513,15 +535,16 @@
         arr = W_NDimArray(out_size, out_shape, dtype=dtype)
         out_iter = ArrayIterator(out_size)
         #TODO: invalidate self, w_other with arr ?
-        
-        me_iter = BroadcastIterator(self,self.shape[:-1] + [1])
+        me_iter = BroadcastIterator(self, self.shape[:-1] + [1])
         assert other_critical_dim >= 0
-        other_iter = BroadcastIterator(self, 
+        other_iter = BroadcastIterator(self,
                                w_other.shape[:other_critical_dim] + [1] + \
                                            w_other.shape[other_critical_dim:])
         while not out_iter.done():
-            i = OneDimIterator(me_iter.get_offset(), self.strides[-1], self.shape[-1])
-            j = OneDimIterator(other_iter.get_offset(), other_critical_dim_stride, other_critical_dim_size)
+            i = OneDimIterator(me_iter.get_offset(),
+                          self.strides[-1], self.shape[-1])
+            j = OneDimIterator(other_iter.get_offset(),
+                          other_critical_dim_stride, other_critical_dim_size)
             #Heres what I would like to do, but how?
             #value = sum(mult_with_iters(self, i, w_other, j))
             #arr.setitem(out_iter, value)
@@ -529,7 +552,6 @@
             me_iter = me_iter.next(0)
             other_iter = other_iter.next(0)
         return arr
-           
 
     def get_concrete(self):
         raise NotImplementedError
@@ -898,7 +920,8 @@
         self.res_dtype = res_dtype
 
     def _del_sources(self):
-        # Function for deleting references to source arrays, to allow garbage-collecting them
+        # Function for deleting references to source arrays,
+        #to allow garbage-collecting them
         raise NotImplementedError
 
     def compute(self):
@@ -993,11 +1016,14 @@
     """
     Intermediate class for performing binary operations.
     """
-    def __init__(self, signature, shape, calc_dtype, res_dtype, left, right):
+    def __init__(self, signature, shape, calc_dtype, res_dtype, left, right,
+             left_start_dim=[-1, -1], right_start_dim=[-1, -1]):
         # XXX do something if left.order != right.order
         VirtualArray.__init__(self, signature, shape, res_dtype, left.order)
         self.left = left
         self.right = right
+        self.left_start_dim = left_start_dim
+        self.right_start_dim = right_start_dim
         self.calc_dtype = calc_dtype
         self.size = 1
         for s in self.shape:
@@ -1015,6 +1041,7 @@
             return self.forced_result.start_iter(res_shape)
         if res_shape is None:
             res_shape = self.shape  # we still force the shape on children
+        #TODO: use left_start_dim, right_start_dim if they are not [-1, -1]
         return Call2Iterator(self.left.start_iter(res_shape),
                              self.right.start_iter(res_shape))
 
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -139,8 +139,12 @@
     def call(self, space, args_w):
         from pypy.module.micronumpy.interp_numarray import (Call2,
             convert_to_array, Scalar, shape_agreement)
-
-        [w_lhs, w_rhs] = args_w
+        if len(args_w)<4:
+            [w_lhs, w_rhs] = args_w
+            w_ssd = space.newlist([space.wrap(-1)]*2)
+            w_osd = space.newlist([space.wrap(-1)]*2)
+        else:
+            [w_lhs, w_rhs, w_ssd, w_osd] = args_w
         w_lhs = convert_to_array(space, w_lhs)
         w_rhs = convert_to_array(space, w_rhs)
         calc_dtype = find_binop_result_dtype(space,
@@ -163,7 +167,7 @@
         ])
         new_shape = shape_agreement(space, w_lhs.shape, w_rhs.shape)
         w_res = Call2(new_sig, new_shape, calc_dtype,
-                      res_dtype, w_lhs, w_rhs)
+                      res_dtype, w_lhs, w_rhs, w_ssd, w_osd)
         w_lhs.add_invalidates(w_res)
         w_rhs.add_invalidates(w_res)
         return w_res
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -716,14 +716,14 @@
         assert dot(range(5), range(5)) == 30
         assert (dot(5, [1, 2, 3]) == [5, 10, 15]).all()
         a = array([[range(4), range(4, 8), range(8, 12)],
-                   [range(12, 16),range(16, 20),range(20, 24)]])
-        raises(ValueError,"a.dot(a)")
+                   [range(12, 16), range(16, 20), range(20, 24)]])
+        raises(ValueError, "a.dot(a)")
         b = a[0, :, :].T
         #Superfluous shape test makes the intention of the test clearer
-        assert a.shape == (2, 3, 4) 
+        assert a.shape == (2, 3, 4)
         assert b.shape == (4, 3)
         c = a.dot(b)
-        assert (c == [[[14, 38,62], [38, 126, 214], [62, 214, 366]], 
+        assert (c == [[[14, 38, 62], [38, 126, 214], [62, 214, 366]],
                    [[86, 302, 518], [110, 390, 670], [134, 478, 822]]]).all()
 
     def test_dot_constant(self):
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to