Author: Brian Kearns <bdkea...@gmail.com> Branch: Changeset: r69530:e431aa28d934 Date: 2014-02-27 20:01 -0500 http://bitbucket.org/pypy/pypy/changeset/e431aa28d934/
Log: optimize multidim_dot loop diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py --- a/pypy/module/micronumpy/loop.py +++ b/pypy/module/micronumpy/loop.py @@ -8,7 +8,8 @@ from rpython.rtyper.lltypesystem import lltype, rffi from pypy.module.micronumpy import support, constants as NPY from pypy.module.micronumpy.base import W_NDimArray -from pypy.module.micronumpy.iterators import PureShapeIter, AxisIter +from pypy.module.micronumpy.iterators import PureShapeIter, AxisIter, \ + AllButAxisIter call2_driver = jit.JitDriver(name='numpy_call2', @@ -259,7 +260,6 @@ argmin = _new_argmin_argmax('min') argmax = _new_argmin_argmax('max') -# note that shapelen == 2 always dot_driver = jit.JitDriver(name = 'numpy_dot', greens = ['dtype'], reds = 'auto') @@ -280,25 +280,30 @@ ''' left_shape = left.get_shape() right_shape = right.get_shape() - broadcast_shape = left_shape[:-1] + right_shape - left_skip = [len(left_shape) - 1 + i for i in range(len(right_shape)) - if i != right_critical_dim] - right_skip = range(len(left_shape) - 1) - result_skip = [len(result.get_shape()) - (len(right_shape) > 1)] + assert left_shape[-1] == right_shape[right_critical_dim] assert result.get_dtype() == dtype - outi = result.implementation.create_dot_iter(broadcast_shape, result_skip) - lefti = left.implementation.create_dot_iter(broadcast_shape, left_skip) - righti = right.implementation.create_dot_iter(broadcast_shape, right_skip) - while not outi.done(): - dot_driver.jit_merge_point(dtype=dtype) - lval = lefti.getitem().convert_to(space, dtype) - rval = righti.getitem().convert_to(space, dtype) - outval = outi.getitem() - v = dtype.itemtype.mul(lval, rval) - v = dtype.itemtype.add(v, outval) - outi.setitem(v) - outi.next() - righti.next() + outi = result.create_iter() + lefti = AllButAxisIter(left.implementation, len(left_shape) - 1) + righti = AllButAxisIter(right.implementation, right_critical_dim) + n = left.implementation.shape[-1] + s1 = left.implementation.strides[-1] + s2 = right.implementation.strides[right_critical_dim] + while not lefti.done(): + while not righti.done(): + oval = outi.getitem() + i1 = lefti.offset + i2 = righti.offset + for _ in xrange(n): + dot_driver.jit_merge_point(dtype=dtype) + lval = left.implementation.getitem(i1).convert_to(space, dtype) + rval = right.implementation.getitem(i2).convert_to(space, dtype) + oval = dtype.itemtype.add(oval, dtype.itemtype.mul(lval, rval)) + i1 += s1 + i2 += s2 + outi.setitem(oval) + outi.next() + righti.next() + righti.reset() lefti.next() return result diff --git a/pypy/module/micronumpy/test/test_arrayops.py b/pypy/module/micronumpy/test/test_arrayops.py --- a/pypy/module/micronumpy/test/test_arrayops.py +++ b/pypy/module/micronumpy/test/test_arrayops.py @@ -41,8 +41,7 @@ a[0] = 0 assert (b == [1, 1, 1, 0, 0]).all() - - def test_dot(self): + def test_dot_basic(self): from numpypy import array, dot, arange a = array(range(5)) assert dot(a, a) == 30.0 @@ -69,7 +68,7 @@ assert b.shape == (4, 3) c = dot(a, b) assert (c == [[[14, 38, 62], [38, 126, 214], [62, 214, 366]], - [[86, 302, 518], [110, 390, 670], [134, 478, 822]]]).all() + [[86, 302, 518], [110, 390, 670], [134, 478, 822]]]).all() c = dot(a, b[:, 2]) assert (c == [[62, 214, 366], [518, 670, 822]]).all() a = arange(3*2*6).reshape((3,2,6)) _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit