Author: Maciej Fijalkowski <[email protected]>
Branch: numpy-refactor
Changeset: r57188:67acfa0c2475
Date: 2012-09-06 21:30 +0200
http://bitbucket.org/pypy/pypy/changeset/67acfa0c2475/
Log: setiter
diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -35,6 +35,9 @@
def done(self):
return self.offset >= self.size
+ def reset(self):
+ self.offset %= self.size
+
class OneDimViewIterator(ConcreteArrayIterator):
def __init__(self, array):
self.array = array
@@ -55,6 +58,9 @@
def done(self):
return self.index >= self.size
+ def reset(self):
+ self.offset %= self.size
+
class MultiDimViewIterator(ConcreteArrayIterator):
def __init__(self, array, start, strides, backstrides, shape):
self.indexes = [0] * len(shape)
@@ -65,6 +71,7 @@
self._done = False
self.strides = strides
self.backstrides = backstrides
+ self.size = array.size
@jit.unroll_safe
def next(self):
@@ -100,6 +107,9 @@
def done(self):
return self._done
+ def reset(self):
+ self.offset %= self.size
+
class AxisIterator(base.BaseArrayIterator):
def __init__(self, array, shape, dim):
self.shape = shape
diff --git a/pypy/module/micronumpy/arrayimpl/scalar.py b/pypy/module/micronumpy/arrayimpl/scalar.py
--- a/pypy/module/micronumpy/arrayimpl/scalar.py
+++ b/pypy/module/micronumpy/arrayimpl/scalar.py
@@ -20,6 +20,9 @@
def done(self):
return False
+ def reset(self):
+ pass
+
class Scalar(base.BaseArrayImplementation):
def __init__(self, dtype, value=None):
self.value = value
diff --git a/pypy/module/micronumpy/interp_flatiter.py b/pypy/module/micronumpy/interp_flatiter.py
--- a/pypy/module/micronumpy/interp_flatiter.py
+++ b/pypy/module/micronumpy/interp_flatiter.py
@@ -1,5 +1,5 @@
-from pypy.module.micronumpy.base import W_NDimArray
+from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
from pypy.module.micronumpy import loop
from pypy.module.micronumpy.strides import to_coords
from pypy.interpreter.baseobjspace import Wrappable
@@ -43,8 +43,6 @@
self.reset()
base = self.base
start, stop, step, length = space.decode_index4(w_idx, base.get_size())
- # setslice would have been better, but flat[u:v] for arbitrary
- # shapes of array a cannot be represented as a[x1:x2, y1:y2]
base_iter = base.create_iter()
base_iter.next_skip_x(start)
if length == 1:
@@ -53,6 +51,16 @@
base.get_order())
return loop.flatiter_getitem(res, base_iter, step)
+ def descr_setitem(self, space, w_idx, w_value):
+ if not (space.isinstance_w(w_idx, space.w_int) or
+ space.isinstance_w(w_idx, space.w_slice)):
+ raise OperationError(space.w_IndexError,
+ space.wrap('unsupported iterator index'))
+ base = self.base
+ start, stop, step, length = space.decode_index4(w_idx, base.get_size())
+ arr = convert_to_array(space, w_value)
+ loop.flatiter_setitem(self.base, arr, start, step, length)
+
def descr_iter(self):
return self
@@ -63,6 +71,7 @@
'flatiter',
__iter__ = interp2app(W_FlatIterator.descr_iter),
__getitem__ = interp2app(W_FlatIterator.descr_getitem),
+ __setitem__ = interp2app(W_FlatIterator.descr_setitem),
__len__ = interp2app(W_FlatIterator.descr_len),
next = interp2app(W_FlatIterator.descr_next),
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -194,3 +194,17 @@
base_iter.next_skip_x(step)
ri.next()
return res
+
+def flatiter_setitem(arr, val, start, step, length):
+ dtype = arr.get_dtype()
+ arr_iter = arr.create_iter()
+ val_iter = val.create_iter()
+ arr_iter.next_skip_x(start)
+ while length > 0:
+ arr_iter.setitem(val_iter.getitem().convert_to(dtype))
+ # need to repeat input values until all assignments are done
+ arr_iter.next_skip_x(step)
+ length -= 1
+ val_iter.next()
+ # reset() wraps val_iter's offset modulo its size, so short values repeat
+ val_iter.reset()
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1831,7 +1831,6 @@
a = arange(12).reshape(3,4)
b = a.T.flat
b[6::2] = [-1, -2]
- print a == [[0, 1, -1, 3], [4, 5, 6, -1], [8, 9, -2, 11]]
assert (a == [[0, 1, -1, 3], [4, 5, 6, -1], [8, 9, -2, 11]]).all()
b[0:2] = [[[100]]]
assert(a[0,0] == 100)
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit