Author: Matti Picus <matti.pi...@gmail.com>
Branch: 
Changeset: r67839:4b5d0c9d1e79
Date: 2013-11-04 23:07 +0200
http://bitbucket.org/pypy/pypy/changeset/4b5d0c9d1e79/
Log: add out to np.dot and ndarray.dot diff --git a/pypy/module/micronumpy/interp_arrayops.py b/pypy/module/micronumpy/interp_arrayops.py --- a/pypy/module/micronumpy/interp_arrayops.py +++ b/pypy/module/micronumpy/interp_arrayops.py @@ -91,11 +91,11 @@ out = W_NDimArray.from_shape(space, shape, dtype) return loop.where(out, shape, arr, x, y, dtype) -def dot(space, w_obj1, w_obj2): +def dot(space, w_obj1, w_obj2, w_out=None): w_arr = convert_to_array(space, w_obj1) if w_arr.is_scalar(): - return convert_to_array(space, w_obj2).descr_dot(space, w_arr) - return w_arr.descr_dot(space, w_obj2) + return convert_to_array(space, w_obj2).descr_dot(space, w_arr, w_out) + return w_arr.descr_dot(space, w_obj2, w_out) @unwrap_spec(axis=int) diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -853,7 +853,14 @@ w_remainder = self.descr_rmod(space, w_other) return space.newtuple([w_quotient, w_remainder]) - def descr_dot(self, space, w_other): + def descr_dot(self, space, w_other, w_out=None): + if space.is_none(w_out): + out = None + elif not isinstance(w_out, W_NDimArray): + raise OperationError(space.w_TypeError, space.wrap( + 'output must be an array')) + else: + out = w_out other = convert_to_array(space, w_other) if other.is_scalar(): #Note: w_out is not modified, this is numpy compliant. @@ -861,7 +868,7 @@ elif len(self.get_shape()) < 2 and len(other.get_shape()) < 2: w_res = self.descr_mul(space, other) assert isinstance(w_res, W_NDimArray) - return w_res.descr_sum(space, space.wrap(-1)) + return w_res.descr_sum(space, space.wrap(-1), out) dtype = interp_ufuncs.find_binop_result_dtype(space, self.get_dtype(), other.get_dtype()) if self.get_size() < 1 and other.get_size() < 1: @@ -869,7 +876,25 @@ return W_NDimArray.new_scalar(space, dtype, space.wrap(0)) # Do the dims match? 
out_shape, other_critical_dim = _match_dot_shapes(space, self, other) - w_res = W_NDimArray.from_shape(space, out_shape, dtype, w_instance=self) + if out: + matches = True + if len(out.get_shape()) != len(out_shape): + matches = False + else: + for i in range(len(out_shape)): + if out.get_shape()[i] != out_shape[i]: + matches = False + break + if dtype != out.get_dtype(): + matches = False + if not out.implementation.order == "C": + matches = False + if not matches: + raise OperationError(space.w_ValueError, space.wrap( + 'output array is not acceptable (must have the right type, nr dimensions, and be a C-Array)')) + w_res = out + else: + w_res = W_NDimArray.from_shape(space, out_shape, dtype, w_instance=self) # This is the place to add fpypy and blas return loop.multidim_dot(space, self, other, w_res, dtype, other_critical_dim) diff --git a/pypy/module/micronumpy/test/test_arrayops.py b/pypy/module/micronumpy/test/test_arrayops.py --- a/pypy/module/micronumpy/test/test_arrayops.py +++ b/pypy/module/micronumpy/test/test_arrayops.py @@ -84,6 +84,17 @@ c = array(3.0).dot(array(4)) assert c == 12.0 + def test_dot_out(self): + from numpypy import arange, dot + a = arange(12).reshape(3, 4) + b = arange(12).reshape(4, 3) + out = arange(9).reshape(3, 3) + c = dot(a, b, out=out) + assert (c == out).all() + out = arange(9,dtype=float).reshape(3, 3) + exc = raises(ValueError, dot, a, b, out) + assert exc.value[0].find('not acceptable') > 0 + def test_choose_basic(self): from numpypy import array a, b, c = array([1, 2, 3]), array([4, 5, 6]), array([7, 8, 9]) _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit