Author: Maciej Fijalkowski <[email protected]>
Branch:
Changeset: r51970:97a32a230bff
Date: 2012-01-30 22:19 +0200
http://bitbucket.org/pypy/pypy/changeset/97a32a230bff/
Log: (mattip, fijal reviewing) merge matrixmath-dot, this adds 1-d and
2-d dot operations that should work. I did not check the actual
numbers ;-)
diff --git a/pypy/module/micronumpy/compile.py b/pypy/module/micronumpy/compile.py
--- a/pypy/module/micronumpy/compile.py
+++ b/pypy/module/micronumpy/compile.py
@@ -34,7 +34,7 @@
SINGLE_ARG_FUNCTIONS = ["sum", "prod", "max", "min", "all", "any",
"unegative", "flat"]
-TWO_ARG_FUNCTIONS = ['take']
+TWO_ARG_FUNCTIONS = ["dot", 'take']  # names dispatched by the `elif self.name in TWO_ARG_FUNCTIONS` branch
class FakeSpace(object):
w_ValueError = None
@@ -410,13 +410,19 @@
else:
assert False # unreachable code
elif self.name in TWO_ARG_FUNCTIONS:
+ if len(self.args) != 2:
+ raise ArgumentMismatch
arg = self.args[1].execute(interp)
if not isinstance(arg, BaseArray):
raise ArgumentNotAnArray
- if self.name == 'take':
+ if not isinstance(arg, BaseArray):
+ raise ArgumentNotAnArray
+ if self.name == "dot":
+ w_res = arr.descr_dot(interp.space, arg)
+ elif self.name == 'take':
w_res = arr.descr_take(interp.space, arg)
else:
- assert False # unreachable
+ assert False # unreachable code
else:
raise WrongFunctionName
if isinstance(w_res, BaseArray):
diff --git a/pypy/module/micronumpy/dot.py b/pypy/module/micronumpy/dot.py
new file mode 100644
--- /dev/null
+++ b/pypy/module/micronumpy/dot.py
@@ -0,0 +1,85 @@
+from pypy.module.micronumpy.strides import calculate_dot_strides
+from pypy.interpreter.error import OperationError
+from pypy.module.micronumpy.interp_iter import ViewIterator
+from pypy.rlib import jit
+
+def dot_printable_location(shapelen):  # label callback passed as dot_driver's get_printable_location
+ return 'numpy dot [%d]' % shapelen
+
+dot_driver = jit.JitDriver(  # JIT driver for the accumulation loop in multidim_dot
+ greens=['shapelen'],  # shapelen: rank of the broadcast shape iterated in multidim_dot
+ reds=['lefti', 'righti', 'outi', 'result', 'right', 'dtype',  # same names as the jit_merge_point keywords in multidim_dot
+ 'left'],
+ get_printable_location=dot_printable_location,
+ name='dot',
+)
+
+def match_dot_shapes(space, left, right):  # returns (out_shape, right_critical_dim); raises ValueError if the matched axes differ in length
+ my_critical_dim_size = left.shape[-1]  # length of left's last axis
+ right_critical_dim_size = right.shape[0]  # requires right to be at least 1-d
+ right_critical_dim = 0  # 1-d right: its only axis is the one matched against left's last axis
+ out_shape = []
+ if len(right.shape) > 1:
+ right_critical_dim = len(right.shape) - 2  # N-d right (N >= 2): the second-to-last axis is matched
+ right_critical_dim_size = right.shape[right_critical_dim]
+ assert right_critical_dim >= 0  # presumably an RPython hint that the slice bounds below are non-negative
+ out_shape += left.shape[:-1] + \
+ right.shape[0:right_critical_dim] + \
+ right.shape[right_critical_dim + 1:]  # left without its last axis, then right without its critical axis
+ elif len(right.shape) > 0:
+ # only a 1-d right gets here (right.shape[0] above excludes 0-d): out_shape is left.shape[:-1]
+ out_shape += left.shape[:-1]
+ if my_critical_dim_size != right_critical_dim_size:
+ raise OperationError(space.w_ValueError, space.wrap(
+ "objects are not aligned"))
+ return out_shape, right_critical_dim
+
+def multidim_dot(space, left, right, result, dtype, right_critical_dim):
+ ''' assumes left, right are concrete arrays
+ given left.shape == [3, 5, 7],
+ right.shape == [2, 7, 4]
+ then
+ result.shape == [3, 5, 2, 4]
+ broadcast shape should be [3, 5, 2, 7, 4]
+ result should skip dims 3 which is len(result_shape) - 1
+ (note that if right is 1d, result should
+ skip len(result_shape))
+ left should skip 2, 4 which is a.ndims-1 + range(right.ndims)
+ except where it==(right.ndims-2)
+ right should skip 0, 1
+ '''
+ broadcast_shape = left.shape[:-1] + right.shape  # left's leading axes followed by all of right's axes
+ shapelen = len(broadcast_shape)  # the green variable of dot_driver
+ left_skip = [len(left.shape) - 1 + i for i in range(len(right.shape))
+ if i != right_critical_dim]  # left's last axis lines up with right's critical axis
+ right_skip = range(len(left.shape) - 1)  # right does not move along left's leading axes
+ result_skip = [len(result.shape) - (len(right.shape) > 1)]  # broadcast index of the summed (critical) axis
+ _r = calculate_dot_strides(result.strides, result.backstrides,
+ broadcast_shape, result_skip)  # stride 0 on the summed axis: outi revisits one cell per summed index
+ outi = ViewIterator(result.start, _r[0], _r[1], broadcast_shape)
+ _r = calculate_dot_strides(left.strides, left.backstrides,
+ broadcast_shape, left_skip)
+ lefti = ViewIterator(left.start, _r[0], _r[1], broadcast_shape)
+ _r = calculate_dot_strides(right.strides, right.backstrides,
+ broadcast_shape, right_skip)
+ righti = ViewIterator(right.start, _r[0], _r[1], broadcast_shape)
+ while not outi.done():  # all three iterators advance in lockstep over broadcast_shape
+ dot_driver.jit_merge_point(left=left,
+ right=right,
+ shapelen=shapelen,
+ lefti=lefti,
+ righti=righti,
+ outi=outi,
+ result=result,
+ dtype=dtype,
+ )
+ lval = left.getitem(lefti.offset).convert_to(dtype)
+ rval = right.getitem(righti.offset).convert_to(dtype)
+ outval = result.getitem(outi.offset).convert_to(dtype)  # NOTE(review): assumes result starts zero-filled - confirm W_NDimArray init
+ v = dtype.itemtype.mul(lval, rval)
+ value = dtype.itemtype.add(v, outval).convert_to(dtype)  # result[out] += left[l] * right[r]
+ result.setitem(outi.offset, value)
+ outi = outi.next(shapelen)
+ righti = righti.next(shapelen)
+ lefti = lefti.next(shapelen)
+ return result
diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -2,7 +2,7 @@
from pypy.rlib import jit
from pypy.rlib.objectmodel import instantiate
from pypy.module.micronumpy.strides import calculate_broadcast_strides,\
- calculate_slice_strides
+ calculate_slice_strides, calculate_dot_strides
""" This is a mini-tutorial on iterators, strides, and
memory layout. It assumes you are familiar with the terms, see
@@ -76,6 +76,7 @@
def __init__(self, res_shape):
self.res_shape = res_shape
+
class BaseIterator(object):
def next(self, shapelen):
raise NotImplementedError
@@ -227,6 +228,7 @@
def transform(self, arr, t):
pass
+
class AxisIterator(BaseIterator):
def __init__(self, start, dim, shape, strides, backstrides):
self.res_shape = shape[:]
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -7,6 +7,7 @@
from pypy.module.micronumpy.strides import (calculate_slice_strides,
shape_agreement, find_shape_and_elems, get_shape_from_iterable,
calc_new_strides, to_coords)
+from dot import multidim_dot, match_dot_shapes
from pypy.rlib import jit
from pypy.rpython.lltypesystem import lltype, rffi
from pypy.tool.sourcetools import func_with_new_name
@@ -203,6 +204,7 @@
frame.next(shapelen)
idx += 1
return result
+
def impl(self, space):
if self.size == 0:
raise OperationError(space.w_ValueError,
@@ -248,13 +250,26 @@
descr_argmin = _reduce_argmax_argmin_impl("min")
def descr_dot(self, space, w_other):
- w_other = convert_to_array(space, w_other)
- if isinstance(w_other, Scalar):
- return self.descr_mul(space, w_other)
- else:
- w_res = self.descr_mul(space, w_other)
+ other = convert_to_array(space, w_other)
+ if isinstance(other, Scalar):  # scalar operand: plain elementwise multiply
+ return self.descr_mul(space, other)
+ elif len(self.shape) < 2 and len(other.shape) < 2:  # both at most 1-d: multiply elementwise, then sum
+ w_res = self.descr_mul(space, other)
assert isinstance(w_res, BaseArray)
return w_res.descr_sum(space, space.wrap(-1))
+ dtype = interp_ufuncs.find_binop_result_dtype(space,
+ self.find_dtype(), other.find_dtype())
+ if self.size < 1 and other.size < 1:
+ # numpy compatibility: two empty operands give a scalar 0
+ return scalar_w(space, dtype, space.wrap(0))
+ # Do the dims match? (match_dot_shapes raises ValueError if not)
+ out_shape, other_critical_dim = match_dot_shapes(space, self, other)
+ out_size = support.product(out_shape)
+ result = W_NDimArray(out_size, out_shape, dtype)  # NOTE(review): multidim_dot adds into this, so it must start zero-filled - confirm
+ # This is the place to add fpypy and blas
+ return multidim_dot(space, self.get_concrete(),
+ other.get_concrete(), result, dtype,
+ other_critical_dim)
def get_concrete(self):
raise NotImplementedError
@@ -1186,6 +1201,8 @@
return space.wrap(s)
def dot(space, w_obj, w_obj2):
+ '''see numpypy.dot. Does not exist as an ndarray method in numpy.
+ '''
w_arr = convert_to_array(space, w_obj)
if isinstance(w_arr, Scalar):
return convert_to_array(space, w_obj2).descr_dot(space, w_arr)
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -216,17 +216,17 @@
sig=sig,
identity=identity,
shapelen=shapelen, arr=arr)
- iter = frame.get_final_iter()
+ iterator = frame.get_final_iter()
v = sig.eval(frame, arr).convert_to(sig.calc_dtype)
- if iter.first_line:
+ if iterator.first_line:
if identity is not None:
value = self.func(sig.calc_dtype, identity, v)
else:
value = v
else:
- cur = arr.left.getitem(iter.offset)
+ cur = arr.left.getitem(iterator.offset)
value = self.func(sig.calc_dtype, cur, v)
- arr.left.setitem(iter.offset, value)
+ arr.left.setitem(iterator.offset, value)
frame.next(shapelen)
def reduce_loop(self, shapelen, sig, frame, value, obj, dtype):
diff --git a/pypy/module/micronumpy/signature.py b/pypy/module/micronumpy/signature.py
--- a/pypy/module/micronumpy/signature.py
+++ b/pypy/module/micronumpy/signature.py
@@ -353,7 +353,6 @@
assert isinstance(arr, Call2)
lhs = self.left.eval(frame, arr.left).convert_to(self.calc_dtype)
rhs = self.right.eval(frame, arr.right).convert_to(self.calc_dtype)
-
return self.binfunc(self.calc_dtype, lhs, rhs)
def debug_repr(self):
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -230,3 +230,18 @@
n_old_elems_to_use *= old_shape[oldI]
assert len(new_strides) == len(new_shape)
return new_strides
+
+
+def calculate_dot_strides(strides, backstrides, res_shape, skip_dims):  # map an array's strides onto res_shape, using 0 on the axes in skip_dims
+ rstrides = []
+ rbackstrides = []
+ j=0  # index of the next unused entry of strides/backstrides
+ for i in range(len(res_shape)):
+ if i in skip_dims:
+ rstrides.append(0)  # stride 0: moving along this axis stays on the same element
+ rbackstrides.append(0)
+ else:
+ rstrides.append(strides[j])  # non-skipped axes consume the array's own strides, in order
+ rbackstrides.append(backstrides[j])
+ j += 1
+ return rstrides, rbackstrides
diff --git a/pypy/module/micronumpy/test/test_compile.py b/pypy/module/micronumpy/test/test_compile.py
--- a/pypy/module/micronumpy/test/test_compile.py
+++ b/pypy/module/micronumpy/test/test_compile.py
@@ -246,6 +246,15 @@
""")
assert interp.results[0].value == 11
+ def test_dot(self):  # 2x2 matrix product through the interpreter's dot()
+ interp = self.run("""
+ a = [[1, 2], [3, 4]]
+ b = [[5, 6], [7, 8]]
+ c = dot(a, b)
+ c -> 0 -> 0
+ """)
+ assert interp.results[0].value == 19  # c[0][0] = 1*5 + 2*7
+
def test_flat_iter(self):
interp = self.run('''
a = |30|
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -559,6 +559,7 @@
b = a * a
for i in range(5):
assert b[i] == i * i
+ assert b.dtype is a.dtype
a = _numpypy.array(range(5), dtype=bool)
b = a * a
@@ -784,8 +785,8 @@
def test_sum(self):
from _numpypy import array
a = array(range(5))
- assert a.sum() == 10.0
- assert a[:4].sum() == 6.0
+ assert a.sum() == 10
+ assert a[:4].sum() == 6
a = array([True] * 5, bool)
assert a.sum() == 5
@@ -910,15 +911,36 @@
assert c.any() == False
def test_dot(self):
- from _numpypy import array, dot
+ from _numpypy import array, dot, arange
a = array(range(5))
- assert a.dot(a) == 30.0
+ assert dot(a, a) == 30.0  # 0 + 1 + 4 + 9 + 16
a = array(range(5))
assert a.dot(range(5)) == 30
assert dot(range(5), range(5)) == 30
assert (dot(5, [1, 2, 3]) == [5, 10, 15]).all()
+ a = arange(12).reshape(3, 4)
+ b = arange(12).reshape(4, 3)
+ c = a.dot(b)  # (3, 4) . (4, 3) -> (3, 3)
+ assert (c == [[ 42, 48, 54], [114, 136, 158], [186, 224, 262]]).all()
+
+ a = arange(24).reshape(2, 3, 4)
+ raises(ValueError, "a.dot(a)")  # a's last axis (4) != its second-to-last axis (3)
+ b = a[0, :, :].T
+ # Superfluous shape checks make the intention of the test clearer
+ assert a.shape == (2, 3, 4)
+ assert b.shape == (4, 3)
+ c = dot(a, b)  # (2, 3, 4) . (4, 3) -> (2, 3, 3)
+ assert (c == [[[14, 38, 62], [38, 126, 214], [62, 214, 366]],
+ [[86, 302, 518], [110, 390, 670], [134, 478, 822]]]).all()
+ c = dot(a, b[:, 2])  # N-d . 1-d: sums over a's last axis -> (2, 3)
+ assert (c == [[62, 214, 366], [518, 670, 822]]).all()
+ a = arange(3*4*5*6).reshape((3,4,5,6))
+ b = arange(3*4*5*6)[::-1].reshape((5,4,6,3))
+ assert dot(a, b)[2,3,2,1,2,2] == 499128  # result shape is (3, 4, 5, 5, 4, 3)
+ assert sum(a[2,3,2,:] * b[1,2,:,2]) == 499128  # the same element computed directly
+
def test_dot_constant(self):
from _numpypy import array
a = array(range(5))
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -456,6 +456,30 @@
'int_rshift': 1,
})
+ def define_dot():  # (3, 4) . (4, 3) product, then read c[1][2]
+ return """
+ a = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
+ b=[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
+ c = dot(a, b)
+ c -> 1 -> 2
+ """
+
+ def test_dot(self):
+ result = self.run("dot")
+ assert result == 184  # c[1][2] = 5*2 + 6*5 + 7*8 + 8*11
+ self.check_simple_loop({'arraylen_gc': 9,  # NOTE(review): presumably exact op counts of the traced loop in multidim_dot; loop-body changes will likely shift them
+ 'float_add': 1,
+ 'float_mul': 1,
+ 'getinteriorfield_raw': 3,
+ 'guard_false': 3,
+ 'guard_true': 3,
+ 'int_add': 6,
+ 'int_lt': 6,
+ 'int_sub': 3,
+ 'jump': 1,
+ 'setinteriorfield_raw': 1})
+
+
class TestNumpyOld(LLJitMixin):
def setup_class(cls):
py.test.skip("old")
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit