Author: Maciej Fijalkowski <[email protected]>
Branch: 
Changeset: r51970:97a32a230bff
Date: 2012-01-30 22:19 +0200
http://bitbucket.org/pypy/pypy/changeset/97a32a230bff/

Log:    (mattip, fijal reviewing) merge matrixmath-dot, this adds 1-d and
        2-d dot operations that should work. I did not check the actual
        numbers ;-)

diff --git a/pypy/module/micronumpy/compile.py b/pypy/module/micronumpy/compile.py
--- a/pypy/module/micronumpy/compile.py
+++ b/pypy/module/micronumpy/compile.py
@@ -34,7 +34,7 @@
 
 SINGLE_ARG_FUNCTIONS = ["sum", "prod", "max", "min", "all", "any",
                         "unegative", "flat"]
-TWO_ARG_FUNCTIONS = ['take']
+TWO_ARG_FUNCTIONS = ["dot", 'take']
 
 class FakeSpace(object):
     w_ValueError = None
@@ -410,13 +410,17 @@
             else:
                 assert False # unreachable code
         elif self.name in TWO_ARG_FUNCTIONS:
+            if len(self.args) != 2:
+                raise ArgumentMismatch
             arg = self.args[1].execute(interp)
             if not isinstance(arg, BaseArray):
                 raise ArgumentNotAnArray
-            if self.name == 'take':
+            if self.name == "dot":
+                w_res = arr.descr_dot(interp.space, arg)
+            elif self.name == 'take':
                 w_res = arr.descr_take(interp.space, arg)
             else:
-                assert False # unreachable
+                assert False # unreachable code
         else:
             raise WrongFunctionName
         if isinstance(w_res, BaseArray):
diff --git a/pypy/module/micronumpy/dot.py b/pypy/module/micronumpy/dot.py
new file mode 100644
--- /dev/null
+++ b/pypy/module/micronumpy/dot.py
@@ -0,0 +1,106 @@
+from pypy.module.micronumpy.strides import calculate_dot_strides
+from pypy.interpreter.error import OperationError
+from pypy.module.micronumpy.interp_iter import ViewIterator
+from pypy.rlib import jit
+
+def dot_printable_location(shapelen):
+    return 'numpy dot [%d]' % shapelen
+
+dot_driver = jit.JitDriver(
+    greens=['shapelen'],
+    reds=['lefti', 'righti', 'outi', 'result', 'right', 'dtype',
+          'left'],
+    get_printable_location=dot_printable_location,
+    name='dot',
+)
+
+def match_dot_shapes(space, left, right):
+    ''' Check that left and right are aligned for dot and compute the
+    shape of the result.
+
+    The last axis of left is contracted against the second-to-last axis
+    of right, or against its only axis when right is 1-d. Both arrays
+    must have at least one dimension, since left.shape[-1] and
+    right.shape[0] are read unconditionally.
+
+    Returns (out_shape, right_critical_dim), where right_critical_dim is
+    the index of the contracted axis of right. Raises an app-level
+    ValueError("objects are not aligned") if the contracted sizes differ.
+    '''
+    my_critical_dim_size = left.shape[-1]
+    right_critical_dim_size = right.shape[0]
+    right_critical_dim = 0
+    out_shape = []
+    if len(right.shape) > 1:
+        right_critical_dim = len(right.shape) - 2
+        right_critical_dim_size = right.shape[right_critical_dim]
+        assert right_critical_dim >= 0
+        out_shape += left.shape[:-1] + \
+                     right.shape[0:right_critical_dim] + \
+                     right.shape[right_critical_dim + 1:]
+    elif len(right.shape) > 0:
+        # right is 1-d: its only axis is contracted away, so the result
+        # has the shape of left without its last axis
+        out_shape += left.shape[:-1]
+    if my_critical_dim_size != right_critical_dim_size:
+        raise OperationError(space.w_ValueError, space.wrap(
+                                        "objects are not aligned"))
+    return out_shape, right_critical_dim
+
+def multidim_dot(space, left, right, result, dtype, right_critical_dim):
+    ''' Add the dot product of left and right into result; return result.
+
+    Assumes left and right are concrete arrays. Every product is added
+    to the value already stored in result, so result has to start out
+    zeroed (NOTE(review): relies on the caller's freshly allocated
+    result being zero-filled -- confirm).
+
+    All three arrays are walked in lockstep over one broadcast shape,
+    left.shape[:-1] + right.shape, using zero strides on the dimensions
+    each of them skips. For example, given
+    left.shape == [3, 5, 7] and right.shape == [2, 7, 4]:
+     result.shape == [3, 5, 2, 4]
+     broadcast shape is [3, 5, 2, 7, 4]
+     result skips dim 3, which is len(result.shape) - 1
+        (if right is 1-d, result skips len(result.shape) instead)
+     left skips dims 2 and 4, which is len(left.shape) - 1 + i for
+        every i in range(len(right.shape)) except right_critical_dim
+     right skips dims 0 and 1, which is range(len(left.shape) - 1)
+    '''
+    broadcast_shape = left.shape[:-1] + right.shape
+    shapelen = len(broadcast_shape)
+    left_skip = [len(left.shape) - 1 + i for i in range(len(right.shape))
+                                         if i != right_critical_dim]
+    right_skip = range(len(left.shape) - 1)
+    # the contracted axis sits just before right's last axis inside
+    # broadcast_shape, or at its very end when right is 1-d
+    result_skip = [len(result.shape) - (len(right.shape) > 1)]
+    _r = calculate_dot_strides(result.strides, result.backstrides,
+                                  broadcast_shape, result_skip)
+    outi = ViewIterator(result.start, _r[0], _r[1], broadcast_shape)
+    _r = calculate_dot_strides(left.strides, left.backstrides,
+                                  broadcast_shape, left_skip)
+    lefti = ViewIterator(left.start, _r[0], _r[1], broadcast_shape)
+    _r = calculate_dot_strides(right.strides, right.backstrides,
+                                  broadcast_shape, right_skip)
+    righti = ViewIterator(right.start, _r[0], _r[1], broadcast_shape)
+    while not outi.done():
+        dot_driver.jit_merge_point(left=left,
+                                   right=right,
+                                   shapelen=shapelen,
+                                   lefti=lefti,
+                                   righti=righti,
+                                   outi=outi,
+                                   result=result,
+                                   dtype=dtype,
+                                  )
+        lval = left.getitem(lefti.offset).convert_to(dtype)
+        rval = right.getitem(righti.offset).convert_to(dtype)
+        outval = result.getitem(outi.offset).convert_to(dtype)
+        v = dtype.itemtype.mul(lval, rval)
+        value = dtype.itemtype.add(v, outval).convert_to(dtype)
+        result.setitem(outi.offset, value)
+        outi = outi.next(shapelen)
+        righti = righti.next(shapelen)
+        lefti = lefti.next(shapelen)
+    return result
diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -2,7 +2,7 @@
 from pypy.rlib import jit
 from pypy.rlib.objectmodel import instantiate
 from pypy.module.micronumpy.strides import calculate_broadcast_strides,\
-     calculate_slice_strides
+     calculate_slice_strides, calculate_dot_strides
 
 """ This is a mini-tutorial on iterators, strides, and
 memory layout. It assumes you are familiar with the terms, see
@@ -76,6 +76,7 @@
     def __init__(self, res_shape):
         self.res_shape = res_shape
 
+
 class BaseIterator(object):
     def next(self, shapelen):
         raise NotImplementedError
@@ -227,6 +228,7 @@
     def transform(self, arr, t):
         pass
 
+
 class AxisIterator(BaseIterator):
     def __init__(self, start, dim, shape, strides, backstrides):
         self.res_shape = shape[:]
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -7,6 +7,7 @@
 from pypy.module.micronumpy.strides import (calculate_slice_strides,
     shape_agreement, find_shape_and_elems, get_shape_from_iterable,
     calc_new_strides, to_coords)
+from pypy.module.micronumpy.dot import multidim_dot, match_dot_shapes
 from pypy.rlib import jit
 from pypy.rpython.lltypesystem import lltype, rffi
 from pypy.tool.sourcetools import func_with_new_name
@@ -203,6 +204,7 @@
                 frame.next(shapelen)
                 idx += 1
             return result
+
         def impl(self, space):
             if self.size == 0:
                 raise OperationError(space.w_ValueError,
@@ -248,13 +250,26 @@
     descr_argmin = _reduce_argmax_argmin_impl("min")
 
     def descr_dot(self, space, w_other):
-        w_other = convert_to_array(space, w_other)
-        if isinstance(w_other, Scalar):
-            return self.descr_mul(space, w_other)
-        else:
-            w_res = self.descr_mul(space, w_other)
+        other = convert_to_array(space, w_other)
+        if isinstance(other, Scalar):
+            return self.descr_mul(space, other)
+        elif len(self.shape) < 2 and len(other.shape) < 2:
+            w_res = self.descr_mul(space, other)
             assert isinstance(w_res, BaseArray)
             return w_res.descr_sum(space, space.wrap(-1))
+        dtype = interp_ufuncs.find_binop_result_dtype(space,
+                                     self.find_dtype(), other.find_dtype())
+        if self.size < 1 and other.size < 1:
+            # NOTE(review): numpy returns an array, not 0, e.g. for (2,0)x(0,3)
+            return scalar_w(space, dtype, space.wrap(0))
+        # raises ValueError unless the contracted dims are aligned
+        out_shape, other_critical_dim = match_dot_shapes(space, self, other)
+        out_size = support.product(out_shape)
+        result = W_NDimArray(out_size, out_shape, dtype)
+        # This is the place to add fpypy and blas
+        return multidim_dot(space, self.get_concrete(),
+                            other.get_concrete(), result, dtype,
+                            other_critical_dim)
 
     def get_concrete(self):
         raise NotImplementedError
@@ -1186,6 +1201,8 @@
     return space.wrap(s)
 
 def dot(space, w_obj, w_obj2):
+    '''see numpypy.dot; the method form is ndarray.dot (descr_dot).
+    A Scalar first operand is swapped to the right-hand side.'''
     w_arr = convert_to_array(space, w_obj)
     if isinstance(w_arr, Scalar):
         return convert_to_array(space, w_obj2).descr_dot(space, w_arr)
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -216,17 +216,17 @@
                                               sig=sig,
                                               identity=identity,
                                               shapelen=shapelen, arr=arr)
-            iter = frame.get_final_iter()
+            iterator = frame.get_final_iter()
             v = sig.eval(frame, arr).convert_to(sig.calc_dtype)
-            if iter.first_line:
+            if iterator.first_line:
                 if identity is not None:
                     value = self.func(sig.calc_dtype, identity, v)
                 else:
                     value = v
             else:
-                cur = arr.left.getitem(iter.offset)
+                cur = arr.left.getitem(iterator.offset)
                 value = self.func(sig.calc_dtype, cur, v)
-            arr.left.setitem(iter.offset, value)
+            arr.left.setitem(iterator.offset, value)
             frame.next(shapelen)
 
     def reduce_loop(self, shapelen, sig, frame, value, obj, dtype):
diff --git a/pypy/module/micronumpy/signature.py b/pypy/module/micronumpy/signature.py
--- a/pypy/module/micronumpy/signature.py
+++ b/pypy/module/micronumpy/signature.py
@@ -353,7 +353,6 @@
         assert isinstance(arr, Call2)
         lhs = self.left.eval(frame, arr.left).convert_to(self.calc_dtype)
         rhs = self.right.eval(frame, arr.right).convert_to(self.calc_dtype)
-        
         return self.binfunc(self.calc_dtype, lhs, rhs)
 
     def debug_repr(self):
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -230,3 +230,25 @@
                     n_old_elems_to_use *= old_shape[oldI]
     assert len(new_strides) == len(new_shape)
     return new_strides
+
+
+def calculate_dot_strides(strides, backstrides, res_shape, skip_dims):
+    ''' Expand strides/backstrides to one entry per dimension of res_shape.
+
+    Every dimension index listed in skip_dims gets a stride and backstride
+    of 0; the remaining dimensions take the entries of strides and
+    backstrides in order, so those must supply one entry per non-skipped
+    dimension. Returns (rstrides, rbackstrides).
+    '''
+    rstrides = []
+    rbackstrides = []
+    j = 0
+    for i in range(len(res_shape)):
+        if i in skip_dims:
+            rstrides.append(0)
+            rbackstrides.append(0)
+        else:
+            rstrides.append(strides[j])
+            rbackstrides.append(backstrides[j])
+            j += 1
+    return rstrides, rbackstrides
diff --git a/pypy/module/micronumpy/test/test_compile.py b/pypy/module/micronumpy/test/test_compile.py
--- a/pypy/module/micronumpy/test/test_compile.py
+++ b/pypy/module/micronumpy/test/test_compile.py
@@ -246,6 +246,15 @@
         """)
         assert interp.results[0].value == 11
 
+    def test_dot(self):
+        interp = self.run("""
+        a = [[1, 2], [3, 4]]
+        b = [[5, 6], [7, 8]]
+        c = dot(a, b)
+        c -> 0 -> 0
+        """)
+        assert interp.results[0].value == 19
+
     def test_flat_iter(self):
         interp = self.run('''
         a = |30|
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -559,6 +559,7 @@
         b = a * a
         for i in range(5):
             assert b[i] == i * i
+        assert b.dtype is a.dtype
 
         a = _numpypy.array(range(5), dtype=bool)
         b = a * a
@@ -784,8 +785,8 @@
     def test_sum(self):
         from _numpypy import array
         a = array(range(5))
-        assert a.sum() == 10.0
-        assert a[:4].sum() == 6.0
+        assert a.sum() == 10
+        assert a[:4].sum() == 6
 
         a = array([True] * 5, bool)
         assert a.sum() == 5
@@ -910,15 +911,36 @@
         assert c.any() == False
 
     def test_dot(self):
-        from _numpypy import array, dot
+        from _numpypy import array, dot, arange
         a = array(range(5))
-        assert a.dot(a) == 30.0
+        assert dot(a, a) == 30.0
 
         a = array(range(5))
         assert a.dot(range(5)) == 30
         assert dot(range(5), range(5)) == 30
         assert (dot(5, [1, 2, 3]) == [5, 10, 15]).all()
 
+        a = arange(12).reshape(3, 4)
+        b = arange(12).reshape(4, 3)
+        c = a.dot(b)
+        assert (c == [[ 42, 48, 54], [114, 136, 158], [186, 224, 262]]).all()
+
+        a = arange(24).reshape(2, 3, 4)
+        raises(ValueError, "a.dot(a)")
+        b = a[0, :, :].T
+        #Superfluous shape test makes the intention of the test clearer
+        assert a.shape == (2, 3, 4)
+        assert b.shape == (4, 3)
+        c = dot(a, b)
+        assert (c == [[[14, 38, 62], [38, 126, 214], [62, 214, 366]],
+                   [[86, 302, 518], [110, 390, 670], [134, 478, 822]]]).all()
+        c = dot(a, b[:, 2])
+        assert (c == [[62, 214, 366], [518, 670, 822]]).all()
+        a = arange(3*4*5*6).reshape((3,4,5,6))
+        b = arange(3*4*5*6)[::-1].reshape((5,4,6,3))
+        assert dot(a, b)[2,3,2,1,2,2] == 499128
+        assert sum(a[2,3,2,:] * b[1,2,:,2]) == 499128
+
     def test_dot_constant(self):
         from _numpypy import array
         a = array(range(5))
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -456,6 +456,30 @@
                                 'int_rshift': 1,
                                 })
 
+    def define_dot():
+        return """
+        a = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
+        b=[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
+        c = dot(a, b)
+        c -> 1 -> 2
+        """
+
+    def test_dot(self):
+        result = self.run("dot")
+        assert result == 184
+        self.check_simple_loop({'arraylen_gc': 9,
+                                'float_add': 1,
+                                'float_mul': 1,
+                                'getinteriorfield_raw': 3,
+                                'guard_false': 3,
+                                'guard_true': 3,
+                                'int_add': 6,
+                                'int_lt': 6,
+                                'int_sub': 3,
+                                'jump': 1,
+                                'setinteriorfield_raw': 1})
+
+
 class TestNumpyOld(LLJitMixin):
     def setup_class(cls):
         py.test.skip("old")
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to