Author: Matti Picus <[email protected]>
Branch: numpypy-nditer
Changeset: r64057:cc555ec312e2
Date: 2013-05-13 22:55 +0300
http://bitbucket.org/pypy/pypy/changeset/cc555ec312e2/

Log:    Shuffle things around to handle op_flags; readwrite almost works.

diff --git a/pypy/module/micronumpy/interp_nditer.py b/pypy/module/micronumpy/interp_nditer.py
--- a/pypy/module/micronumpy/interp_nditer.py
+++ b/pypy/module/micronumpy/interp_nditer.py
@@ -2,19 +2,118 @@
 from pypy.interpreter.typedef import TypeDef, GetSetProperty, 
make_weakref_descr
 from pypy.interpreter.gateway import interp2app, unwrap_spec, WrappedDefault
 from pypy.interpreter.error import OperationError
-from pypy.module.micronumpy.base import convert_to_array
+from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
 from pypy.module.micronumpy.strides import calculate_broadcast_strides
 from pypy.module.micronumpy.iter import MultiDimViewIterator
 from pypy.module.micronumpy import support
 
+def parse_op_arg(space, name, w_op_flags, n, parse_one_arg):
+    '''
+    Parse a per-operand argument such as op_flags into a list of n
+    interp-level values, one per operand.
+
+    w_op_flags may be:
+      * None - every operand gets whatever parse_one_arg returns for an
+        empty flag list (its defaults);
+      * a flat tuple/list of flags - parsed once and shared by all n
+        operands;
+      * a tuple/list of n tuples/lists - one flag sequence per operand.
+    parse_one_arg(space, lst) always receives an interp-level list of
+    wrapped items and returns the parsed value for one operand.
+    Raises an app-level ValueError (mentioning `name`) for any other
+    shape, including an empty sequence.
+    '''
+    ret = []
+    if space.is_w(w_op_flags, space.w_None):
+        # Let the per-operand parser supply its own defaults.
+        for i in range(n):
+            ret.append(parse_one_arg(space, []))
+        return ret
+    if not space.isinstance_w(w_op_flags, space.w_tuple) and not \
+           space.isinstance_w(w_op_flags, space.w_list):
+        raise OperationError(space.w_ValueError, space.wrap(
+                '%s must be a tuple or array of per-op flag-tuples' % name))
+    w_lst = space.listview(w_op_flags)
+    if len(w_lst) == 0:
+        # Nothing to inspect; w_lst[0] below would crash at interp level.
+        raise OperationError(space.w_ValueError, space.wrap(
+                '%s must be a tuple or array of per-op flag-tuples' % name))
+    if space.isinstance_w(w_lst[0], space.w_tuple) or \
+       space.isinstance_w(w_lst[0], space.w_list):
+        # One flag sequence per operand.
+        if len(w_lst) != n:
+            raise OperationError(space.w_ValueError, space.wrap(
+                   '%s must be a tuple or array of per-op flag-tuples' % name))
+        for w_item in w_lst:
+            # w_lst is already an interp-level list; unwrap each inner
+            # sequence so parse_one_arg gets the same kind of argument as
+            # in the single-sequence branch below.
+            ret.append(parse_one_arg(space, space.listview(w_item)))
+    else:
+        # A single flat flag sequence shared by every operand.
+        op_flag = parse_one_arg(space, w_lst)
+        for i in range(n):
+            ret.append(op_flag)
+    return ret
 
-def handle_sequence_args(space, cls, w_seq, w_op_flags, w_op_types, w_op_axes):
-    '''
-    Make sure that len(args) == 1 or len(w_seq)
-    and set attribs on cls appropriately
-    '''
-    raise OperationError(space.w_NotImplementedError, space.wrap(
-        'not implemented yet'))
+class OpFlag(object):
+    '''
+    Parsed per-operand nditer flags.  A fresh instance describes a
+    read-only operand that may be broadcast, with every other option off.
+    '''
+    def __init__(self):
+        # access mode: 'r' (readonly), 'rw' (readwrite) or 'w' (writeonly)
+        self.rw = 'r'
+        # callable(space, it) producing the value handed out for this operand
+        self.get_it_item = get_readonly_item
+        self.broadcast = True
+        self.allocate = False
+        self.force_align = False
+        self.force_contig = False
+        self.native_byte_order = False
+        # '' = no temporary copy, 'r' = 'copy', 'rw' = 'updateifcopy'
+        self.tmp_copy = ''
+
+def get_readonly_item(space, it):
+    '''Return the iterator's current element as a wrapped app-level value.'''
+    element = it.getitem()
+    return space.wrap(element)
+
+def get_readwrite_item(space, it):
+    '''
+    Return a one-element array (shape [1], same dtype) holding the
+    iterator's current element, so the caller can assign through it.
+
+    NOTE(review): the array comes from W_NDimArray.from_shape and the
+    element is copied into it, so `x[...] = v` on the result does not
+    appear to be written back into the array being iterated.  A view onto
+    the iterator's current position is presumably needed - TODO confirm.
+    '''
+    res = W_NDimArray.from_shape([1], it.dtype, it.array.order)
+    # copy the current element into slot 0 of the new array
+    it.dtype.setitem(res.implementation, 0, it.getitem())
+    return res
+
+def parse_op_flag(space, lst):
+    '''
+    Build an OpFlag from lst, an interp-level list of wrapped strings
+    holding the flags of a single operand (e.g. 'readwrite',
+    'no_broadcast').
+
+    Raises an app-level NotImplementedError for flags that are recognised
+    but not supported yet, and ValueError for unknown flags.
+    '''
+    op_flag = OpFlag()
+    for w_item in lst:
+        item = space.str_w(w_item)
+        if item == 'readonly':
+            op_flag.rw = 'r'
+        elif item == 'readwrite':
+            op_flag.rw = 'rw'
+        elif item == 'writeonly':
+            op_flag.rw = 'w'
+        elif item == 'no_broadcast':
+            op_flag.broadcast = False
+        elif item == 'contig':
+            op_flag.force_contig = True
+        elif item == 'aligned':
+            op_flag.force_align = True
+        elif item == 'nbo':
+            op_flag.native_byte_order = True
+        elif item == 'copy':
+            op_flag.tmp_copy = 'r'
+        elif item == 'updateifcopy':
+            op_flag.tmp_copy = 'rw'
+        elif item == 'allocate':
+            op_flag.allocate = True
+        elif item == 'no_subtype':
+            raise OperationError(space.w_NotImplementedError, space.wrap(
+                    '"no_subtype" op_flag not implemented yet'))
+        elif item == 'arraymask':
+            raise OperationError(space.w_NotImplementedError, space.wrap(
+                    '"arraymask" op_flag not implemented yet'))
+        elif item == 'writemask':
+            raise OperationError(space.w_NotImplementedError, space.wrap(
+                    '"writemask" op_flag not implemented yet'))
+        else:
+            raise OperationError(space.w_ValueError, space.wrap(
+                    'op_flags must be a tuple or array of per-op flag-tuples'))
+    # Choose the item getter once, after every flag has been seen (the
+    # last access-mode flag wins).  Operands that may be written to
+    # ('rw' and 'w') need an array object so that `x[...] = value` is
+    # legal; read-only operands get the wrapped element itself.
+    if op_flag.rw == 'r':
+        op_flag.get_it_item = get_readonly_item
+    else:
+        op_flag.get_it_item = get_readwrite_item
+    return op_flag
+
+def get_iter(space, order, imp, backward):
+    '''
+    Build a MultiDimViewIterator over the array implementation `imp`,
+    walking it as requested by `order` ('K', 'C' or 'F').
+
+    The incoming value of `backward` is never read: it is recomputed
+    below from `order` and `imp.order`.  Raises an app-level
+    NotImplementedError for order/layout combinations not handled yet.
+    '''
+    if order == 'K' or (order == 'C' and imp.order == 'C'):
+        backward = False
+    elif order =='F' and imp.order == 'C':
+        # Fortran-order traversal of a C-ordered array
+        backward = True
+    else:
+        raise OperationError(space.w_NotImplementedError, space.wrap(
+                'not implemented yet'))
+    if (imp.strides[0] < imp.strides[-1] and not backward) or \
+       (imp.strides[0] > imp.strides[-1] and backward):
+        # flip the strides. Is this always true for multidimension?
+        strides = [s for s in imp.strides[::-1]]
+        backstrides = [s for s in imp.backstrides[::-1]]
+        shape = [s for s in imp.shape[::-1]]
+    else:
+        strides = imp.strides
+        backstrides = imp.backstrides
+        shape = imp.shape
+    # one-entry shape list: the product of imp.shape (the element count)
+    shape1d = [support.product(imp.shape),]
+    r = calculate_broadcast_strides(strides, backstrides, shape,
+                                    shape1d, backward)
+    # r[0] / r[1] are used as the iterator's strides / backstrides
+    # (assumed from their positions in the constructor call - confirm)
+    return MultiDimViewIterator(imp, imp.dtype, imp.start, r[0], r[1], shape)
 
 
 class W_NDIter(W_Root):
@@ -22,33 +121,18 @@
     def __init__(self, space, w_seq, w_flags, w_op_flags, w_op_dtypes, 
w_casting,
             w_op_axes, w_itershape, w_buffersize, order):
+        '''
+        Convert w_seq (one array-like, or a tuple/list of them) into
+        arrays, parse op_flags into one OpFlag per operand, and build one
+        iterator per operand.  w_flags, w_op_dtypes, w_casting, w_op_axes,
+        w_itershape and w_buffersize are accepted but not used here yet.
+        '''
         self.order = order
-        if space.isinstance_w(w_seq, space.w_tuple) or 
space.isinstance_w(w_seq, space.w_list):
-            handle_sequence_args(space, self, w_seq, w_op_flags, w_op_dtypes, 
w_op_axes)
+        if space.isinstance_w(w_seq, space.w_tuple) or \
+           space.isinstance_w(w_seq, space.w_list):
+            w_seq_as_list = space.listview(w_seq)
+            self.seq = [convert_to_array(space, w_elem) for w_elem in 
w_seq_as_list]
         else:
             self.seq =[convert_to_array(space, w_seq)]
-            if order == 'K' or (order == 'C' and self.seq[0].get_order() == 
'C'):
-                backward = False
-            elif order =='F' and self.seq[0].get_order() == 'C':
-                backward = True
-            else:
-                raise OperationError(space.w_NotImplementedError, space.wrap(
-                        'not implemented yet'))
-            imp = self.seq[0].implementation
-            if (imp.strides[0] < imp.strides[-1] and not backward) or \
-               (imp.strides[0] > imp.strides[-1] and backward):
-                # flip the strides. Is this always true for multidimension?
-                strides = [s for s in imp.strides[::-1]]
-                backstrides = [s for s in imp.backstrides[::-1]]
-                shape = [s for s in imp.shape[::-1]]
-            else:
-                strides = imp.strides
-                backstrides = imp.backstrides
-                shape = imp.shape
-            shape1d = [support.product(imp.shape),]
-            r = calculate_broadcast_strides(strides, backstrides, shape,
-                                            shape1d, backward)
-            self.iters = [MultiDimViewIterator(imp, imp.dtype, imp.start, 
r[0], r[1],
-                            shape)]
+        self.op_flags = parse_op_arg(space, 'op_flags', w_op_flags,
+                                     len(self.seq), parse_op_flag)
+        self.iters = []
+        for i in range(len(self.seq)):
+            # get_iter derives the traversal direction from self.order and
+            # the array layout itself; its `backward` parameter expects a
+            # bool, so pass one instead of this operand's OpFlag (mixing an
+            # instance and a bool in one variable does not annotate).
+            self.iters.append(get_iter(space, self.order,
+                            self.seq[i].implementation, False))
 
     def descr_iter(self, space):
+        # Returns this nditer object itself (wrapped), so it acts as its own
+        # iterator - presumably exposed as __iter__; confirm in the TypeDef.
         return space.wrap(self)
@@ -72,9 +156,9 @@
         else:
             raise OperationError(space.w_StopIteration, space.w_None)
         res = []
-        for it in self.iters:
-            res.append(space.wrap(it.getitem()))
-            it.next()
+        for i in range(len(self.iters)):
+            res.append(self.op_flags[i].get_it_item(space, self.iters[i]))
+            self.iters[i].next()
         if len(res) <2:
             return res[0]
         return space.newtuple(res)
diff --git a/pypy/module/micronumpy/test/test_nditer.py b/pypy/module/micronumpy/test/test_nditer.py
--- a/pypy/module/micronumpy/test/test_nditer.py
+++ b/pypy/module/micronumpy/test/test_nditer.py
@@ -39,6 +39,7 @@
         from numpypy import arange, nditer
         a = arange(6).reshape(2,3)
         for x in nditer(a, op_flags=['readwrite']):
+            print x,x.shape
             x[...] = 2 * x
         assert (a == [[0, 2, 4], [6, 8, 10]]).all()
 
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to