Author: mattip
Branch: numpy NDimArray
Changeset: r48214:1be18bf996dd
Date: 2011-10-18 23:48 +0200
http://bitbucket.org/pypy/pypy/changeset/1be18bf996dd/

Log:    Add shape checking, implement NDim binary ops (mult and div are
        wrong)

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -107,6 +107,12 @@
 
     def _binop_impl(ufunc_name):
         def impl(self, space, w_other):
+            if not space.eq_w(self.descr_shape(space),w_other.descr_shape(space)):
+                raise OperationError(space.w_ValueError,
+                    space.wrap("shape mismatch: objects cannot be broadcast to a single shape %s <> %s" \
+                            % (self.descr_shape(space).unwrap(space), 
+                              w_other.descr_shape(space).unwrap(space),
+                            )))
             return getattr(interp_ufuncs.get(space), ufunc_name).call(space, [self, w_other])
         return func_with_new_name(impl, "binop_%s_impl" % ufunc_name)
 
@@ -297,7 +303,7 @@
             length = space.len_w(w_idx)
             if length > 1: # only one dimension for now.
                 raise OperationError(space.w_IndexError,
-                                     space.wrap("invalid index"))
+                                     space.wrap("invalid index: cannot handle tuple indices yet"))
             if length == 0:
                 w_idx = space.newslice(space.wrap(0),
                                       space.wrap(self.find_size()),
@@ -392,8 +398,12 @@
     def compute(self):
         i = 0
         signature = self.signature
+        result_shape = self.find_shape()
         result_size = self.find_size()
-        result = SingleDimArray(result_size, self.find_dtype())
+        if len(result_shape)>1:
+            result = NDimArray(result_shape, self.find_dtype())
+        else:
+            result = SingleDimArray(result_size, self.find_dtype())
         while i < result_size:
             numpy_driver.jit_merge_point(signature=signature,
                                          result_size=result_size, i=i,
@@ -419,6 +429,12 @@
     def setitem(self, item, value):
         return self.get_concrete().setitem(item, value)
 
+    def find_shape(self):
+        if self.forced_result is not None:
+            # The result has been computed and sources may be unavailable
+            return self.forced_result.find_shape()
+        return self._find_shape()
+
     def find_size(self):
         if self.forced_result is not None:
             # The result has been computed and sources may be unavailable
@@ -466,6 +482,13 @@
         self.left = None
         self.right = None
 
+    def _find_shape(self):
+        try:
+            return self.left.find_shape()
+        except ValueError:
+            pass
+        return self.right.find_shape()
+
     def _find_size(self):
         try:
             return self.left.find_size()
@@ -538,6 +561,9 @@
     def get_root_storage(self):
         return self.parent.get_concrete().get_root_storage()
 
+    def find_shape(self):
+        return (self.size,)
+
     def find_size(self):
         return self.size
 
@@ -582,6 +608,9 @@
     def get_root_storage(self):
         return self.storage
 
+    def find_shape(self):
+        return (self.size,)
+
     def find_size(self):
         return self.size
 
@@ -639,6 +668,9 @@
     def get_root_storage(self):
         return self.storage
 
+    def find_shape(self):
+        return self.shape
+
     def find_size(self):
         return self.size
 
@@ -649,7 +681,7 @@
         return self.dtype.getitem(self.storage, i)
 
     def descr_shape(self, space):
-        return self.shape
+        return space.wrap(self.shape)
 
     def descr_len(self, space):
         return space.wrap(self.size)
@@ -673,7 +705,6 @@
                 self.setitem_recurse_w(space,i+index*self.shape[-depth], depth+1, w_item)
                 i+=1
         else:
-            print 'setting',index,'to',w_value
             self.setitem_w(space,index,w_value) 
     def setitem(self, item, value):
         self.invalidated()
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to