Author: mattip Branch: matrixmath-dot Changeset: r51506:2bcfa95fe92a Date: 2012-01-20 09:41 +0200 http://bitbucket.org/pypy/pypy/changeset/2bcfa95fe92a/
Log: refactor and rework, still need more tests diff --git a/pypy/module/micronumpy/dot.py b/pypy/module/micronumpy/dot.py new file mode 100644 --- /dev/null +++ b/pypy/module/micronumpy/dot.py @@ -0,0 +1,68 @@ +from pypy.module.micronumpy import interp_ufuncs +from pypy.module.micronumpy.strides import calculate_dot_strides +from pypy.interpreter.error import OperationError, operationerrfmt +from pypy.module.micronumpy.interp_iter import ViewIterator + + +def match_dot_shapes(space, left, right): + my_critical_dim_size = left.shape[-1] + right_critical_dim_size = right.shape[0] + right_critical_dim = 0 + right_critical_dim_stride = right.strides[0] + out_shape = [] + if len(right.shape) > 1: + right_critical_dim = len(right.shape) - 2 + right_critical_dim_size = right.shape[right_critical_dim] + right_critical_dim_stride = right.strides[right_critical_dim] + assert right_critical_dim >= 0 + out_shape += left.shape[:-1] + \ + right.shape[0:right_critical_dim] + \ + right.shape[right_critical_dim + 1:] + elif len(right.shape) > 0: + #dot does not reduce for scalars + out_shape += left.shape[:-1] + if my_critical_dim_size != right_critical_dim_size: + raise OperationError(space.w_ValueError, space.wrap( + "objects are not aligned")) + return out_shape, right_critical_dim + + +def multidim_dot(space, left, right, result, dtype, right_critical_dim): + ''' assumes left, right are concrete arrays + given left.shape == [3, 5, 7], + right.shape == [2, 7, 4] + result.shape == [3, 5, 2, 4] + broadcast shape should be [3, 5, 2, 7, 4] + result should skip dims 3 which is results.ndims - 1 + left should skip 2, 4 which is a.ndims-1 + range(right.ndims) + except where it==(right.ndims-2) + right should skip 0, 1 + ''' + mul = interp_ufuncs.get(space).multiply.func + add = interp_ufuncs.get(space).add.func + broadcast_shape = left.shape[:-1] + right.shape + left_skip = [len(left.shape) - 1 + i for i in range(len(right.shape)) + if i != right_critical_dim] + right_skip = range(len(left.shape) - 1) + result_skip = [len(result.shape) - 1] + shapelen = len(broadcast_shape) + _r = calculate_dot_strides(result.strides, result.backstrides, + broadcast_shape, result_skip) + outi = ViewIterator(0, _r[0], _r[1], broadcast_shape) + _r = calculate_dot_strides(left.strides, left.backstrides, + broadcast_shape, left_skip) + lefti = ViewIterator(0, _r[0], _r[1], broadcast_shape) + _r = calculate_dot_strides(right.strides, right.backstrides, + broadcast_shape, right_skip) + righti = ViewIterator(0, _r[0], _r[1], broadcast_shape) + while not outi.done(): + v = mul(dtype, left.getitem(lefti.offset), + right.getitem(righti.offset)) + value = add(dtype, v, result.getitem(outi.offset)) + result.setitem(outi.offset, value) + outi = outi.next(shapelen) + righti = righti.next(shapelen) + lefti = lefti.next(shapelen) + assert lefti.done() + assert righti.done() + return result diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py --- a/pypy/module/micronumpy/interp_iter.py +++ b/pypy/module/micronumpy/interp_iter.py @@ -16,10 +16,6 @@ def __init__(self, res_shape): self.res_shape = res_shape -class DotTransform(BaseTransform): - def __init__(self, res_shape, skip_dims): - self.res_shape = res_shape - self.skip_dims = skip_dims class BaseIterator(object): def next(self, shapelen): @@ -90,10 +86,6 @@ self.strides, self.backstrides, t.chunks) return ViewIterator(r[1], r[2], r[3], r[0]) - elif isinstance(t, DotTransform): - r = calculate_dot_strides(self.strides, self.backstrides, - t.res_shape, t.skip_dims) - return ViewIterator(self.offset, r[0], r[1], t.res_shape) @jit.unroll_safe def next(self, shapelen): diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -3,14 +3,15 @@ from pypy.interpreter.gateway import interp2app, NoneNotWrapped from pypy.interpreter.typedef import TypeDef, GetSetProperty from pypy.module.micronumpy import interp_ufuncs, interp_dtype, signature -from pypy.module.micronumpy.strides import calculate_slice_strides,\ - calculate_dot_strides +from pypy.module.micronumpy.strides import calculate_slice_strides from pypy.rlib import jit from pypy.rpython.lltypesystem import lltype, rffi from pypy.tool.sourcetools import func_with_new_name from pypy.rlib.rstring import StringBuilder from pypy.module.micronumpy.interp_iter import ArrayIterator, OneDimIterator,\ - SkipLastAxisIterator, ViewIterator + SkipLastAxisIterator +from pypy.module.micronumpy.dot import multidim_dot, match_dot_shapes, dot_docstring + numpy_driver = jit.JitDriver( greens=['shapelen', 'sig'], @@ -212,28 +213,6 @@ n_old_elems_to_use *= old_shape[oldI] return new_strides -def match_dot_shapes(space, self, other): - my_critical_dim_size = self.shape[-1] - other_critical_dim_size = other.shape[0] - other_critical_dim = 0 - other_critical_dim_stride = other.strides[0] - out_shape = [] - if len(other.shape) > 1: - other_critical_dim = len(other.shape) - 2 - other_critical_dim_size = other.shape[other_critical_dim] - other_critical_dim_stride = other.strides[other_critical_dim] - assert other_critical_dim >= 0 - out_shape += self.shape[:-1] + \ - other.shape[0:other_critical_dim] + \ - other.shape[other_critical_dim + 1:] - elif len(other.shape) > 0: - #dot does not reduce for scalars - out_shape += self.shape[:-1] - if my_critical_dim_size != other_critical_dim_size: - raise OperationError(space.w_ValueError, space.wrap( - "objects are not aligned")) - return out_shape, other_critical_dim - class BaseArray(Wrappable): _attrs_ = ["invalidates", "shape", 'size'] @@ -399,14 +378,6 @@ descr_argmin = _reduce_argmax_argmin_impl("min") def descr_dot(self, space, w_other): - '''Dot product of two arrays. - - For 2-D arrays it is equivalent to matrix multiplication, and for 1-D - arrays to inner product of vectors (without complex conjugation). For - N dimensions it is a sum product over the last axis of `a` and - the second-to-last of `b`:: - - dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])''' other = convert_to_array(space, w_other) if isinstance(other, Scalar): return self.descr_mul(space, other) @@ -425,43 +396,10 @@ for o in out_shape: out_size *= o result = W_NDimArray(out_size, out_shape, dtype) - # given a.shape == [3, 5, 7], - # b.shape == [2, 7, 4] - # result.shape == [3, 5, 2, 4] - # all iterators shapes should be [3, 5, 2, 7, 4] - # result should skip dims 3 which is results.ndims - 1 - # a should skip 2, 4 which is a.ndims-1 + range(b.ndims) - # except where it==(b.ndims-2) - # b should skip 0, 1 - mul = interp_ufuncs.get(space).multiply.func - add = interp_ufuncs.get(space).add.func - broadcast_shape = self.shape[:-1] + other.shape - #Aww, cmon, this is the product of a warped mind. - left_skip = [len(self.shape) - 1 + i for i in range(len(other.shape)) if i != other_critical_dim] - right_skip = range(len(self.shape) - 1) - arr = DotArray(mul, 'DotName', out_shape, dtype, self, other, - left_skip, right_skip) - arr.broadcast_shape = broadcast_shape - arr.result_skip = [len(out_shape) - 1] - #Make this lazy someday... - sig = signature.find_sig(signature.DotSignature(mul, 'dot', dtype, - self.create_sig(), other.create_sig()), arr) - assert isinstance(sig, signature.DotSignature) - self.do_dot_loop(sig, result, arr, add) - return result - - def do_dot_loop(self, sig, result, arr, add): - frame = sig.create_frame(arr) - shapelen = len(arr.broadcast_shape) - _r = calculate_dot_strides(result.strides, result.backstrides, - arr.broadcast_shape, arr.result_skip) - ri = ViewIterator(0, _r[0], _r[1], arr.broadcast_shape) - while not frame.done(): - v = sig.eval(frame, arr).convert_to(sig.calc_dtype) - value = add(sig.calc_dtype, v, result.getitem(ri.offset)) - result.setitem(ri.offset, value) - frame.next(shapelen) - ri = ri.next(shapelen) + # This is the place to add fpypy and blas + return multidim_dot(space, self.get_concrete(), + other.get_concrete(), result, dtype, + other_critical_dim) def get_concrete(self): raise NotImplementedError @@ -933,23 +871,6 @@ left, right) self.dim = dim -class DotArray(Call2): - """ NOTE: this is only used as a container, you should never - encounter such things in the wild. Remove this comment - when we'll make Dot lazy - """ - _immutable_fields_ = ['left', 'right'] - - def __init__(self, ufunc, name, shape, dtype, left, right, left_skip, right_skip): - Call2.__init__(self, ufunc, name, shape, dtype, dtype, - left, right) - self.left_skip = left_skip - self.right_skip = right_skip - def create_sig(self): - #if self.forced_result is not None: - # return self.forced_result.create_sig() - assert NotImplementedError - class ConcreteArray(BaseArray): """ An array that have actual storage, whether owned or not """ @@ -1304,6 +1225,8 @@ return space.wrap(arr) def dot(space, w_obj, w_obj2): + '''see numpypy.dot. Does not exist as an ndarray method in numpy. + ''' w_arr = convert_to_array(space, w_obj) if isinstance(w_arr, Scalar): return convert_to_array(space, w_obj2).descr_dot(space, w_arr) diff --git a/pypy/module/micronumpy/signature.py b/pypy/module/micronumpy/signature.py --- a/pypy/module/micronumpy/signature.py +++ b/pypy/module/micronumpy/signature.py @@ -2,7 +2,7 @@ from pypy.rlib.rarithmetic import intmask from pypy.module.micronumpy.interp_iter import ViewIterator, ArrayIterator, \ ConstantIterator, AxisIterator, ViewTransform,\ - BroadcastTransform, DotTransform + BroadcastTransform from pypy.rlib.jit import hint, unroll_safe, promote """ Signature specifies both the numpy expression that has been constructed @@ -449,21 +449,3 @@ def debug_repr(self): return 'AxisReduceSig(%s, %s)' % (self.name, self.right.debug_repr()) - -class DotSignature(Call2): - def _invent_numbering(self, cache, allnumbers): - self.left._invent_numbering(new_cache(), allnumbers) - self.right._invent_numbering(new_cache(), allnumbers) - - def _create_iter(self, iterlist, arraylist, arr, transforms): - from pypy.module.micronumpy.interp_numarray import DotArray - - assert isinstance(arr, DotArray) - rtransforms = transforms + [DotTransform(arr.broadcast_shape, arr.right_skip)] - ltransforms = transforms + [DotTransform(arr.broadcast_shape, arr.left_skip)] - self.left._create_iter(iterlist, arraylist, arr.left, ltransforms) - self.right._create_iter(iterlist, arraylist, arr.right, rtransforms) - - def debug_repr(self): - return 'DotSig(%s, %s %s)' % (self.name, self.right.debug_repr(), - self.left.debug_repr()) diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py --- a/pypy/module/micronumpy/test/test_numarray.py +++ b/pypy/module/micronumpy/test/test_numarray.py @@ -869,7 +869,7 @@ def test_dot(self): from _numpypy import array, dot, arange a = array(range(5)) - assert a.dot(a) == 30.0 + assert dot(a, a) == 30.0 a = array(range(5)) assert a.dot(range(5)) == 30 @@ -887,9 +887,11 @@ #Superfluous shape test makes the intention of the test clearer assert a.shape == (2, 3, 4) assert b.shape == (4, 3) - c = a.dot(b) + c = dot(a, b) assert (c == [[[14, 38, 62], [38, 126, 214], [62, 214, 366]], [[86, 302, 518], [110, 390, 670], [134, 478, 822]]]).all() + c = dot(a, b[:, :, 2]) + assert (c == [[38, 126, 214], [302, 390, 478]]).all() def test_dot_constant(self): from _numpypy import array _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit