Author: Matti Picus <[email protected]>
Branch: numpypy-nditer
Changeset: r64066:8b31168ea449
Date: 2013-05-14 14:07 +0300
http://bitbucket.org/pypy/pypy/changeset/8b31168ea449/
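Log: readwrite works (nditer with op_flags=['readwrite'] now yields writable single-element views into the iterated array instead of copies)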
diff --git a/pypy/module/micronumpy/interp_nditer.py b/pypy/module/micronumpy/interp_nditer.py
--- a/pypy/module/micronumpy/interp_nditer.py
+++ b/pypy/module/micronumpy/interp_nditer.py
@@ -6,6 +6,7 @@
from pypy.module.micronumpy.strides import calculate_broadcast_strides
from pypy.module.micronumpy.iter import MultiDimViewIterator
from pypy.module.micronumpy import support
+from pypy.module.micronumpy.arrayimpl.concrete import SliceArray
def parse_op_arg(space, name, w_op_flags, n, parse_one_arg):
ret = []
@@ -42,13 +43,14 @@
self.allocate = False
self.get_it_item = get_readonly_item
-def get_readonly_item(space, it):
+def get_readonly_item(space, array, it):
return space.wrap(it.getitem())
-def get_readwrite_item(space, it):
- res = W_NDimArray.from_shape([1], it.dtype, it.array.order)
- it.dtype.setitem(res.implementation, 0, it.getitem())
- return res
+def get_readwrite_item(space, array, it):
+ #create a single-value view (since scalars are not views)
+ res = SliceArray(it.array.start + it.offset, [0], [0], [1,], it.array,
array)
+ #it.dtype.setitem(res, 0, it.getitem())
+ return W_NDimArray(res)
def parse_op_flag(space, lst):
op_flag = OpFlag()
@@ -157,7 +159,8 @@
raise OperationError(space.w_StopIteration, space.w_None)
res = []
for i in range(len(self.iters)):
- res.append(self.op_flags[i].get_it_item(space, self.iters[i]))
+ res.append(self.op_flags[i].get_it_item(space, self.seq[i],
+ self.iters[i]))
self.iters[i].next()
if len(res) <2:
return res[0]
diff --git a/pypy/module/micronumpy/test/test_nditer.py b/pypy/module/micronumpy/test/test_nditer.py
--- a/pypy/module/micronumpy/test/test_nditer.py
+++ b/pypy/module/micronumpy/test/test_nditer.py
@@ -39,7 +39,6 @@
from numpypy import arange, nditer
a = arange(6).reshape(2,3)
for x in nditer(a, op_flags=['readwrite']):
- print x,x.shape
x[...] = 2 * x
assert (a == [[0, 2, 4], [6, 8, 10]]).all()
@@ -53,6 +52,7 @@
r = []
for x in nditer(a, flags=['external_loop'], order='F'):
r.append(x)
+ print r
assert (array(r) == [[0, 3], [1, 4], [2, 5]]).all()
def test_interface(self):
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit