Author: mattip
Branch: numpypy-axisops
Changeset: r50913:da091f2c5c7d
Date: 2011-12-26 23:32 +0200
http://bitbucket.org/pypy/pypy/changeset/da091f2c5c7d/

Log:    checkpoint

diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -103,34 +103,59 @@
     def next(self, shapelen):
         return self
 
-# ------ other iterators that are not part of the computation frame ----------
+def axis_iter_from_arr(arr, dim=-1, start=[]):
+    # build an AxisIterator from arr's start offset and memory layout
+    return AxisIterator(arr.start, arr.strides, arr.backstrides, arr.shape,
+                        dim, start)
 
 class AxisIterator(object):
     """ This object will return offsets of each start of a stride on the 
         desired dimension, starting at the desired index
     """
-    def __init__(self, arr, dim=-1, start=[]):
-        self.arr = arr
-        self.indices = [0] * len(arr.shape)
+    def __init__(self, offset, strides, backstrides, shape, dim=-1, start=[]):
+        # offset is the storage offset of the first element; start is an
+        # optional list of per-dimension start indices (only read, never
+        # mutated, so the shared default list is safe)
+        self.shape = shape
+        self.strides = strides
+        self.backstrides = backstrides
+        self.indices = [0] * len(shape)
         self.done = False
-        self.offset = arr.start
-        self.dim = len(arr.shape) - 1
+        self.offset = offset
+        self.dim = len(shape) - 1
         if dim >= 0:
             self.dim = dim
-        if len(start) == len(arr.shape):
+        if len(start) == len(shape):
             for i in range(len(start)):
-                self.offset += arr.strides[i] * start[i]
-    def next(self):
-        for i in range(len(self.arr.shape) - 1, -1, -1):
+                self.offset += strides[i] * start[i]
+
+    def next(self, shapelen):
+        # returns a new iterator advanced by one position (dimension
+        # self.dim is never incremented); self is left unchanged
+        offset = self.offset
+        indices = [0] * shapelen
+        for i in range(shapelen):
+            indices[i] = self.indices[i]
+        done = False
+        for i in range(shapelen - 1, -1, -1):
             if i == self.dim:
                 continue
-            if self.indices[i] < self.arr.shape[i] - 1:
-                self.indices[i] += 1
-                self.offset += self.arr.strides[i]
+            if indices[i] < self.shape[i] - 1:
+                indices[i] += 1
+                offset += self.strides[i]
                 break
             else:
-                self.indices[i] = 0
-                self.offset -= self.arr.backstrides[i]
+                indices[i] = 0
+                offset -= self.backstrides[i]
         else:
-            self.done = True
-        
+            done = True
+        res = instantiate(AxisIterator)
+        res.offset = offset
+        res.indices = indices
+        res.strides = self.strides
+        res.backstrides = self.backstrides
+        res.shape = self.shape
+        res.dim = self.dim
+        res.done = done
+        return res
+
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -664,7 +664,8 @@
         self.name = name
 
     def _del_sources(self):
-        # Function for deleting references to source arrays, to allow garbage-collecting them
+        # Function for deleting references to source arrays,
+        # to allow garbage-collecting them
         raise NotImplementedError
 
     def compute(self):
@@ -730,6 +731,55 @@
     def _del_sources(self):
         self.child = None
 
+class Reduce(VirtualArray):
+    """ Virtual array reducing `values` with `ufunc` along axis `dim`; its
+        shape is values.shape with axis `dim` removed.
+    """
+    def __init__(self, ufunc, name, dim, res_dtype, values):
+        shape = values.shape[0:dim] + values.shape[dim+1:len(values.shape)]
+        VirtualArray.__init__(self, name, shape, res_dtype)
+        self.values = values
+        # the result has one dimension fewer than values, so its size is the
+        # product of the reduced shape rather than values.size
+        size = 1
+        for s in shape:
+            size *= s
+        self.size = size
+        self.ufunc = ufunc
+        self.res_dtype = res_dtype
+        self.dim = dim
+
+    def _del_sources(self):
+        self.values = None
+
+    def create_sig(self, res_shape):
+        if self.forced_result is not None:
+            return self.forced_result.create_sig(res_shape)
+        return signature.ReduceSignature(self.ufunc, self.name, self.res_dtype,
+                           signature.ViewSignature(self.res_dtype),
+                           self.values.create_sig(res_shape))
+
+    def compute(self):
+        from pypy.module.micronumpy.interp_iter import axis_iter_from_arr
+        result = W_NDimArray(self.size, self.shape, self.find_dtype())
+        shapelen = len(result.shape)
+        # si walks self.values, which has one more dimension than result
+        values_shapelen = len(self.values.shape)
+        sig = self.find_sig()
+        ri = ArrayIterator(self.size)
+        si = axis_iter_from_arr(self.values, self.dim)
+        while not ri.done():
+            # build the frame over self.values: create_frame's second
+            # positional parameter is res_shape, not an array
+            frame = sig.create_frame(self.values, chunks=si.indices)
+            # TODO(review): one sig.eval per output element and no loop over
+            # self.dim here -- confirm the ufunc is folded along the axis
+            val = sig.eval(frame, self.values)
+            result.dtype.setitem(result.storage, ri.offset, val)
+            ri = ri.next(shapelen)
+            si = si.next(values_shapelen)
+        return result
+
 class Call1(VirtualArray):
     def __init__(self, ufunc, name, shape, res_dtype, values):
         VirtualArray.__init__(self, name, shape, res_dtype)
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -3,8 +3,8 @@
 from pypy.interpreter.gateway import interp2app, unwrap_spec
 from pypy.interpreter.typedef import TypeDef, GetSetProperty, interp_attrproperty
 from pypy.module.micronumpy import interp_boxes, interp_dtype, types
-from pypy.module.micronumpy.signature import (ReduceSignature, ScalarSignature, 
-                            ArraySignature, find_sig)
+from pypy.module.micronumpy.signature import (ReduceSignature,
+                            ScalarSignature, find_sig)
 from pypy.rlib import jit
 from pypy.rlib.rarithmetic import LONG_BIT
 from pypy.tool.sourcetools import func_with_new_name
@@ -124,12 +124,11 @@
             promote_to_largest=True
         )
         shapelen = len(obj.shape)
-        if dim>=0 or 0:
-            sig = find_sig(ReduceSignature(self.func, self.name, dtype,
-                                       ArraySignature(dtype),
-                                       obj.create_sig(obj.shape)), obj)
-        else:
-            sig = find_sig(ReduceSignature(self.func, self.name, dtype,
+        if shapelen > 1 and dim >= 0:
+            # an axis reduction of an N-d array is delegated to Reduce
+            from pypy.module.micronumpy.interp_numarray import Reduce
+            return Reduce(self.func, self.name, dim, dtype, obj)
+        sig = find_sig(ReduceSignature(self.func, self.name, dtype,
                                        ScalarSignature(dtype),
                                        obj.create_sig(obj.shape)), obj)
         frame = sig.create_frame(obj)
diff --git a/pypy/module/micronumpy/signature.py b/pypy/module/micronumpy/signature.py
--- a/pypy/module/micronumpy/signature.py
+++ b/pypy/module/micronumpy/signature.py
@@ -90,11 +90,13 @@
             allnumbers.append(no)
         self.iter_no = no
 
-    def create_frame(self, arr, res_shape=None):
+    def create_frame(self, arr, res_shape=None, chunks=None):
         res_shape = res_shape or arr.shape
+        if chunks is None:
+            chunks = []
         iterlist = []
         arraylist = []
-        self._create_iter(iterlist, arraylist, arr, res_shape, [])
+        self._create_iter(iterlist, arraylist, arr, res_shape, chunks)
         return NumpyEvalFrame(iterlist, arraylist)
 
 class ConcreteSignature(Signature):
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -338,16 +338,20 @@
         raises(ValueError, sin.reduce, [1, 2, 3])
         raises(TypeError, add.reduce, 1)
 
-    def test_reduce(self):
-        from numpypy import add, maximum, arange
-
+    def test_reduce1D(self):
+        from numpypy import add, maximum
         assert add.reduce([1, 2, 3]) == 6
         assert maximum.reduce([1]) == 1
         assert maximum.reduce([1, 2, 3]) == 3
         raises(ValueError, maximum.reduce, [])
-        a = arange(12).reshape(3,4)
-        assert add.reduce(a, 0) == add.reduce(a)
+
+    def test_reduceND(self):
+        from numpypy import add, arange
+        a = arange(12).reshape(3, 4)
+        assert add.reduce(a, 1)[0] == 6
         assert (add.reduce(a, 1) == [ 6, 22, 38]).all()
+        assert (add.reduce(a, 0) == [12, 15, 18, 21]).all()
+        assert (add.reduce(a, 0) == add.reduce(a)).all()
 
     def test_comparisons(self):
         import operator
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to