Author: mattip
Branch: matrixmath-dot
Changeset: r51432:f62709780578
Date: 2012-01-18 02:22 +0200
http://bitbucket.org/pypy/pypy/changeset/f62709780578/
Log: passes a test, needs cleanup diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py --- a/pypy/module/micronumpy/interp_iter.py +++ b/pypy/module/micronumpy/interp_iter.py @@ -2,7 +2,7 @@ from pypy.rlib import jit from pypy.rlib.objectmodel import instantiate from pypy.module.micronumpy.strides import calculate_broadcast_strides,\ - calculate_slice_strides + calculate_slice_strides, calculate_dot_strides class BaseTransform(object): pass @@ -16,6 +16,11 @@ def __init__(self, res_shape): self.res_shape = res_shape +class DotTransform(BaseTransform): + def __init__(self, res_shape, skip_dims): + self.res_shape = res_shape + self.skip_dims = skip_dims + class BaseIterator(object): def next(self, shapelen): raise NotImplementedError @@ -85,6 +90,10 @@ self.strides, self.backstrides, t.chunks) return ViewIterator(r[1], r[2], r[3], r[0]) + elif isinstance(t, DotTransform): + r = calculate_dot_strides(self.strides, self.backstrides, + t.res_shape, t.skip_dims) + return ViewIterator(self.offset, r[0], r[1], t.res_shape) @jit.unroll_safe def next(self, shapelen): @@ -130,6 +139,7 @@ def transform(self, arr, t): pass + class AxisIterator(BaseIterator): def __init__(self, start, dim, shape, strides, backstrides): self.res_shape = shape[:] diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -3,13 +3,14 @@ from pypy.interpreter.gateway import interp2app, NoneNotWrapped from pypy.interpreter.typedef import TypeDef, GetSetProperty from pypy.module.micronumpy import interp_ufuncs, interp_dtype, signature -from pypy.module.micronumpy.strides import calculate_slice_strides +from pypy.module.micronumpy.strides import calculate_slice_strides,\ + calculate_dot_strides from pypy.rlib import jit from pypy.rpython.lltypesystem import lltype, rffi from pypy.tool.sourcetools import func_with_new_name from 
pypy.rlib.rstring import StringBuilder from pypy.module.micronumpy.interp_iter import ArrayIterator, OneDimIterator,\ - SkipLastAxisIterator + SkipLastAxisIterator, ViewIterator numpy_driver = jit.JitDriver( greens=['shapelen', 'sig'], @@ -211,6 +212,28 @@ n_old_elems_to_use *= old_shape[oldI] return new_strides +def match_dot_shapes(space, self, other): + my_critical_dim_size = self.shape[-1] + other_critical_dim_size = other.shape[0] + other_critical_dim = 0 + other_critical_dim_stride = other.strides[0] + out_shape = [] + if len(other.shape) > 1: + other_critical_dim = len(other.shape) - 2 + other_critical_dim_size = other.shape[other_critical_dim] + other_critical_dim_stride = other.strides[other_critical_dim] + assert other_critical_dim >= 0 + out_shape += self.shape[:-1] + \ + other.shape[0:other_critical_dim] + \ + other.shape[other_critical_dim + 1:] + elif len(other.shape) > 0: + #dot does not reduce for scalars + out_shape += self.shape[:-1] + if my_critical_dim_size != other_critical_dim_size: + raise OperationError(space.w_ValueError, space.wrap( + "objects are not aligned")) + return out_shape, other_critical_dim + class BaseArray(Wrappable): _attrs_ = ["invalidates", "shape", 'size'] @@ -384,70 +407,62 @@ the second-to-last of `b`:: dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])''' - w_other = convert_to_array(space, w_other) - if isinstance(w_other, Scalar): - return self.descr_mul(space, w_other) - elif len(self.shape) < 2 and len(w_other.shape) < 2: - w_res = self.descr_mul(space, w_other) + other = convert_to_array(space, w_other) + if isinstance(other, Scalar): + return self.descr_mul(space, other) + elif len(self.shape) < 2 and len(other.shape) < 2: + w_res = self.descr_mul(space, other) assert isinstance(w_res, BaseArray) return w_res.descr_sum(space, space.wrap(-1)) dtype = interp_ufuncs.find_binop_result_dtype(space, - self.find_dtype(), w_other.find_dtype()) - if self.size < 1 and w_other.size < 1: + self.find_dtype(), other.find_dtype()) + 
if self.size < 1 and other.size < 1: #numpy compatability return scalar_w(space, dtype, space.wrap(0)) #Do the dims match? - my_critical_dim_size = self.shape[-1] - other_critical_dim_size = w_other.shape[0] - other_critical_dim = 0 - other_critical_dim_stride = w_other.strides[0] - out_shape = [] - if len(w_other.shape) > 1: - other_critical_dim = len(w_other.shape) - 2 - other_critical_dim_size = w_other.shape[other_critical_dim] - other_critical_dim_stride = w_other.strides[other_critical_dim] - assert other_critical_dim >= 0 - out_shape += self.shape[:-1] + \ - w_other.shape[0:other_critical_dim] + \ - w_other.shape[other_critical_dim + 1:] - elif len(w_other.shape) > 0: - #dot does not reduce for scalars - out_shape += self.shape[:-1] - if my_critical_dim_size != other_critical_dim_size: - raise OperationError(space.w_ValueError, space.wrap( - "objects are not aligned")) + out_shape, other_critical_dim = match_dot_shapes(space, self, other) out_size = 1 - for os in out_shape: - out_size *= os - out_ndims = len(out_shape) - # TODO: what should the order be? C or F? - arr = W_NDimArray(out_size, out_shape, dtype=dtype) - # TODO: this is all a bogus mess of previous work, - # rework within the context of transformations - ''' - out_iter = ViewIterator(arr.start, arr.strides, arr.backstrides, arr.shape) - # TODO: invalidate self, w_other with arr ? 
- while not out_iter.done(): - my_index = self.start - other_index = w_other.start - i = 0 - while i < len(self.shape) - 1: - my_index += out_iter.indices[i] * self.strides[i] - i += 1 - for j in range(len(w_other.shape) - 2): - other_index += out_iter.indices[i] * w_other.strides[j] - other_index += out_iter.indices[-1] * w_other.strides[-1] - w_ssd = space.newlist([space.wrap(my_index), - space.wrap(len(self.shape) - 1)]) - w_osd = space.newlist([space.wrap(other_index), - space.wrap(other_critical_dim)]) - w_res = self.descr_mul(space, w_other) - assert isinstance(w_res, BaseArray) - value = w_res.descr_sum(space) - arr.setitem(out_iter.get_offset(), value) - out_iter = out_iter.next(out_ndims) - ''' - return arr + for o in out_shape: + out_size *= o + result = W_NDimArray(out_size, out_shape, dtype) + # given a.shape == [3, 5, 7], + # b.shape == [2, 7, 4] + # result.shape == [3, 5, 2, 4] + # all iterators shapes should be [3, 5, 2, 7, 4] + # result should skip dims 3 which is results.ndims - 1 + # a should skip 2, 4 which is a.ndims-1 + range(b.ndims) + # except where it==(b.ndims-2) + # b should skip 0, 1 + mul = interp_ufuncs.get(space).multiply.func + add = interp_ufuncs.get(space).add.func + broadcast_shape = self.shape[:-1] + other.shape + #Aww, cmon, this is the product of a warped mind. + left_skip = [len(self.shape) - 1 + i for i in range(len(other.shape)) if i != other_critical_dim] + right_skip = range(len(self.shape) - 1) + arr = DotArray(mul, 'DotName', out_shape, dtype, self, other, + left_skip, right_skip) + arr.broadcast_shape = broadcast_shape + arr.result_skip = [len(out_shape) - 1] + #Make this lazy someday... 
+ sig = signature.find_sig(signature.DotSignature(mul, 'dot', dtype, + self.create_sig(), other.create_sig()), arr) + assert isinstance(sig, signature.DotSignature) + self.do_dot_loop(sig, result, arr, add) + return result + + def do_dot_loop(self, sig, result, arr, add): + frame = sig.create_frame(arr) + shapelen = len(arr.broadcast_shape) + _r = calculate_dot_strides(result.strides, result.backstrides, + arr.broadcast_shape, arr.result_skip) + ri = ViewIterator(0, _r[0], _r[1], arr.broadcast_shape) + while not frame.done(): + v = sig.eval(frame, arr).convert_to(sig.calc_dtype) + z = result.getitem(ri.offset) + value = add(sig.calc_dtype, v, result.getitem(ri.offset)) + result.setitem(ri.offset, value) + frame.next(shapelen) + ri = ri.next(shapelen) def get_concrete(self): raise NotImplementedError @@ -919,6 +934,23 @@ left, right) self.dim = dim +class DotArray(Call2): + """ NOTE: this is only used as a container, you should never + encounter such things in the wild. Remove this comment + when we'll make Dot lazy + """ + _immutable_fields_ = ['left', 'right'] + + def __init__(self, ufunc, name, shape, dtype, left, right, left_skip, right_skip): + Call2.__init__(self, ufunc, name, shape, dtype, dtype, + left, right) + self.left_skip = left_skip + self.right_skip = right_skip + def create_sig(self): + #if self.forced_result is not None: + # return self.forced_result.create_sig() + assert NotImplementedError + class ConcreteArray(BaseArray): """ An array that have actual storage, whether owned or not """ diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py --- a/pypy/module/micronumpy/interp_ufuncs.py +++ b/pypy/module/micronumpy/interp_ufuncs.py @@ -192,17 +192,17 @@ sig=sig, identity=identity, shapelen=shapelen, arr=arr) - iter = frame.get_final_iter() + iterator = frame.get_final_iter() v = sig.eval(frame, arr).convert_to(sig.calc_dtype) - if iter.first_line: + if iterator.first_line: if identity is not None: value = 
self.func(sig.calc_dtype, identity, v) else: value = v else: - cur = arr.left.getitem(iter.offset) + cur = arr.left.getitem(iterator.offset) value = self.func(sig.calc_dtype, cur, v) - arr.left.setitem(iter.offset, value) + arr.left.setitem(iterator.offset, value) frame.next(shapelen) def reduce_loop(self, shapelen, sig, frame, value, obj, dtype): diff --git a/pypy/module/micronumpy/signature.py b/pypy/module/micronumpy/signature.py --- a/pypy/module/micronumpy/signature.py +++ b/pypy/module/micronumpy/signature.py @@ -2,7 +2,7 @@ from pypy.rlib.rarithmetic import intmask from pypy.module.micronumpy.interp_iter import ViewIterator, ArrayIterator, \ ConstantIterator, AxisIterator, ViewTransform,\ - BroadcastTransform + BroadcastTransform, DotTransform from pypy.rlib.jit import hint, unroll_safe, promote """ Signature specifies both the numpy expression that has been constructed @@ -331,7 +331,6 @@ assert isinstance(arr, Call2) lhs = self.left.eval(frame, arr.left).convert_to(self.calc_dtype) rhs = self.right.eval(frame, arr.right).convert_to(self.calc_dtype) - return self.binfunc(self.calc_dtype, lhs, rhs) def debug_repr(self): @@ -450,3 +449,21 @@ def debug_repr(self): return 'AxisReduceSig(%s, %s)' % (self.name, self.right.debug_repr()) + +class DotSignature(Call2): + def _invent_numbering(self, cache, allnumbers): + self.left._invent_numbering(new_cache(), allnumbers) + self.right._invent_numbering(new_cache(), allnumbers) + + def _create_iter(self, iterlist, arraylist, arr, transforms): + from pypy.module.micronumpy.interp_numarray import DotArray + + assert isinstance(arr, DotArray) + rtransforms = transforms + [DotTransform(arr.broadcast_shape, arr.right_skip)] + ltransforms = transforms + [DotTransform(arr.broadcast_shape, arr.left_skip)] + self.left._create_iter(iterlist, arraylist, arr.left, ltransforms) + self.right._create_iter(iterlist, arraylist, arr.right, rtransforms) + + def debug_repr(self): + return 'DotSig(%s, %s %s)' % (self.name, 
self.right.debug_repr(), + self.left.debug_repr()) diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py --- a/pypy/module/micronumpy/strides.py +++ b/pypy/module/micronumpy/strides.py @@ -37,3 +37,17 @@ rstrides = [0] * (len(res_shape) - len(orig_shape)) + rstrides rbackstrides = [0] * (len(res_shape) - len(orig_shape)) + rbackstrides return rstrides, rbackstrides + +def calculate_dot_strides(strides, backstrides, res_shape, skip_dims): + rstrides = [] + rbackstrides = [] + j=0 + for i in range(len(res_shape)): + if i in skip_dims: + rstrides.append(0) + rbackstrides.append(0) + else: + rstrides.append(strides[j]) + rbackstrides.append(backstrides[j]) + j += 1 + return rstrides, rbackstrides diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py --- a/pypy/module/micronumpy/test/test_numarray.py +++ b/pypy/module/micronumpy/test/test_numarray.py @@ -867,7 +867,7 @@ assert c.any() == False def test_dot(self): - from _numpypy import array, dot + from _numpypy import array, dot, arange a = array(range(5)) assert a.dot(a) == 30.0 @@ -876,13 +876,12 @@ assert dot(range(5), range(5)) == 30 assert (dot(5, [1, 2, 3]) == [5, 10, 15]).all() - a = array([range(4), range(4, 8), range(8, 12)]) - b = array([range(3), range(3, 6), range(6, 9), range(9, 12)]) + a = arange(12).reshape(3, 4) + b = arange(12).reshape(4, 3) c = a.dot(b) assert (c == [[ 42, 48, 54], [114, 136, 158], [186, 224, 262]]).all() - a = array([[range(4), range(4, 8), range(8, 12)], - [range(12, 16), range(16, 20), range(20, 24)]]) + a = arange(24).reshape(2, 3, 4) raises(ValueError, "a.dot(a)") b = a[0, :, :].T #Superfluous shape test makes the intention of the test clearer _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit