Author: Matti Picus <matti.pi...@gmail.com>
Branch: 
Changeset: r67839:4b5d0c9d1e79
Date: 2013-11-04 23:07 +0200
http://bitbucket.org/pypy/pypy/changeset/4b5d0c9d1e79/

Log:    add out to np.dot and ndarray.dot

diff --git a/pypy/module/micronumpy/interp_arrayops.py b/pypy/module/micronumpy/interp_arrayops.py
--- a/pypy/module/micronumpy/interp_arrayops.py
+++ b/pypy/module/micronumpy/interp_arrayops.py
@@ -91,11 +91,11 @@
     out = W_NDimArray.from_shape(space, shape, dtype)
     return loop.where(out, shape, arr, x, y, dtype)
 
-def dot(space, w_obj1, w_obj2):
+def dot(space, w_obj1, w_obj2, w_out=None):
     w_arr = convert_to_array(space, w_obj1)
     if w_arr.is_scalar():
-        return convert_to_array(space, w_obj2).descr_dot(space, w_arr)
-    return w_arr.descr_dot(space, w_obj2)
+        return convert_to_array(space, w_obj2).descr_dot(space, w_arr, w_out)
+    return w_arr.descr_dot(space, w_obj2, w_out)
 
 
 @unwrap_spec(axis=int)
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -853,7 +853,14 @@
         w_remainder = self.descr_rmod(space, w_other)
         return space.newtuple([w_quotient, w_remainder])
 
-    def descr_dot(self, space, w_other):
+    def descr_dot(self, space, w_other, w_out=None):
+        if space.is_none(w_out):
+            out = None
+        elif not isinstance(w_out, W_NDimArray):
+            raise OperationError(space.w_TypeError, space.wrap(
+                    'output must be an array'))
+        else:
+            out = w_out
         other = convert_to_array(space, w_other)
         if other.is_scalar():
             #Note: w_out is not modified, this is numpy compliant.
@@ -861,7 +868,7 @@
         elif len(self.get_shape()) < 2 and len(other.get_shape()) < 2:
             w_res = self.descr_mul(space, other)
             assert isinstance(w_res, W_NDimArray)
-            return w_res.descr_sum(space, space.wrap(-1))
+            return w_res.descr_sum(space, space.wrap(-1), out)
         dtype = interp_ufuncs.find_binop_result_dtype(space,
                                      self.get_dtype(), other.get_dtype())
         if self.get_size() < 1 and other.get_size() < 1:
@@ -869,7 +876,25 @@
             return W_NDimArray.new_scalar(space, dtype, space.wrap(0))
         # Do the dims match?
         out_shape, other_critical_dim = _match_dot_shapes(space, self, other)
-        w_res = W_NDimArray.from_shape(space, out_shape, dtype, w_instance=self)
+        if out:
+            matches = True
+            if len(out.get_shape()) != len(out_shape):
+                matches = False
+            else:
+                for i in range(len(out_shape)):
+                    if out.get_shape()[i] != out_shape[i]:
+                        matches = False
+                        break
+            if dtype != out.get_dtype():
+                matches = False
+            if not out.implementation.order == "C":
+                matches = False
+            if not matches:
+                raise OperationError(space.w_ValueError, space.wrap(
+                    'output array is not acceptable (must have the right type, nr dimensions, and be a C-Array)'))
+            w_res = out
+        else:
+            w_res = W_NDimArray.from_shape(space, out_shape, dtype, w_instance=self)
         # This is the place to add fpypy and blas
         return loop.multidim_dot(space, self, other,  w_res, dtype,
                                  other_critical_dim)
diff --git a/pypy/module/micronumpy/test/test_arrayops.py b/pypy/module/micronumpy/test/test_arrayops.py
--- a/pypy/module/micronumpy/test/test_arrayops.py
+++ b/pypy/module/micronumpy/test/test_arrayops.py
@@ -84,6 +84,17 @@
         c = array(3.0).dot(array(4))
         assert c == 12.0
 
+    def test_dot_out(self):
+        from numpypy import arange, dot
+        a = arange(12).reshape(3, 4)
+        b = arange(12).reshape(4, 3)
+        out = arange(9).reshape(3, 3)
+        c = dot(a, b, out=out)
+        assert (c == out).all()
+        out = arange(9,dtype=float).reshape(3, 3)
+        exc = raises(ValueError, dot, a, b, out)
+        assert exc.value[0].find('not acceptable') > 0
+
     def test_choose_basic(self):
         from numpypy import array
         a, b, c = array([1, 2, 3]), array([4, 5, 6]), array([7, 8, 9])
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to