Author: Matti Picus <[email protected]>
Branch:
Changeset: r67839:4b5d0c9d1e79
Date: 2013-11-04 23:07 +0200
http://bitbucket.org/pypy/pypy/changeset/4b5d0c9d1e79/
Log: add out argument to np.dot and ndarray.dot
diff --git a/pypy/module/micronumpy/interp_arrayops.py b/pypy/module/micronumpy/interp_arrayops.py
--- a/pypy/module/micronumpy/interp_arrayops.py
+++ b/pypy/module/micronumpy/interp_arrayops.py
@@ -91,11 +91,11 @@
out = W_NDimArray.from_shape(space, shape, dtype)
return loop.where(out, shape, arr, x, y, dtype)
-def dot(space, w_obj1, w_obj2):
+def dot(space, w_obj1, w_obj2, w_out=None):
w_arr = convert_to_array(space, w_obj1)
if w_arr.is_scalar():
- return convert_to_array(space, w_obj2).descr_dot(space, w_arr)
- return w_arr.descr_dot(space, w_obj2)
+ return convert_to_array(space, w_obj2).descr_dot(space, w_arr, w_out)
+ return w_arr.descr_dot(space, w_obj2, w_out)
@unwrap_spec(axis=int)
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -853,7 +853,14 @@
w_remainder = self.descr_rmod(space, w_other)
return space.newtuple([w_quotient, w_remainder])
- def descr_dot(self, space, w_other):
+ def descr_dot(self, space, w_other, w_out=None):
+ if space.is_none(w_out):
+ out = None
+ elif not isinstance(w_out, W_NDimArray):
+ raise OperationError(space.w_TypeError, space.wrap(
+ 'output must be an array'))
+ else:
+ out = w_out
other = convert_to_array(space, w_other)
if other.is_scalar():
#Note: w_out is not modified, this is numpy compliant.
@@ -861,7 +868,7 @@
elif len(self.get_shape()) < 2 and len(other.get_shape()) < 2:
w_res = self.descr_mul(space, other)
assert isinstance(w_res, W_NDimArray)
- return w_res.descr_sum(space, space.wrap(-1))
+ return w_res.descr_sum(space, space.wrap(-1), out)
dtype = interp_ufuncs.find_binop_result_dtype(space,
self.get_dtype(), other.get_dtype())
if self.get_size() < 1 and other.get_size() < 1:
@@ -869,7 +876,25 @@
return W_NDimArray.new_scalar(space, dtype, space.wrap(0))
# Do the dims match?
out_shape, other_critical_dim = _match_dot_shapes(space, self, other)
- w_res = W_NDimArray.from_shape(space, out_shape, dtype,
w_instance=self)
+ if out:
+ matches = True
+ if len(out.get_shape()) != len(out_shape):
+ matches = False
+ else:
+ for i in range(len(out_shape)):
+ if out.get_shape()[i] != out_shape[i]:
+ matches = False
+ break
+ if dtype != out.get_dtype():
+ matches = False
+ if not out.implementation.order == "C":
+ matches = False
+ if not matches:
+ raise OperationError(space.w_ValueError, space.wrap(
+ 'output array is not acceptable (must have the right type,
nr dimensions, and be a C-Array)'))
+ w_res = out
+ else:
+ w_res = W_NDimArray.from_shape(space, out_shape, dtype,
w_instance=self)
# This is the place to add fpypy and blas
return loop.multidim_dot(space, self, other, w_res, dtype,
other_critical_dim)
diff --git a/pypy/module/micronumpy/test/test_arrayops.py b/pypy/module/micronumpy/test/test_arrayops.py
--- a/pypy/module/micronumpy/test/test_arrayops.py
+++ b/pypy/module/micronumpy/test/test_arrayops.py
@@ -84,6 +84,17 @@
c = array(3.0).dot(array(4))
assert c == 12.0
+ def test_dot_out(self):
+ from numpypy import arange, dot
+ a = arange(12).reshape(3, 4)
+ b = arange(12).reshape(4, 3)
+ out = arange(9).reshape(3, 3)
+ c = dot(a, b, out=out)
+ assert (c == out).all()
+ out = arange(9,dtype=float).reshape(3, 3)
+ exc = raises(ValueError, dot, a, b, out)
+ assert exc.value[0].find('not acceptable') > 0
+
def test_choose_basic(self):
from numpypy import array
a, b, c = array([1, 2, 3]), array([4, 5, 6]), array([7, 8, 9])
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit