Author: mattip Branch: numppy-flatitter Changeset: r51670:ff368a6d0ff7 Date: 2012-01-19 01:16 +0200 http://bitbucket.org/pypy/pypy/changeset/ff368a6d0ff7/
Log: redo, add lots of tests, some still fail diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -1290,53 +1290,100 @@ tolist = interp2app(BaseArray.descr_tolist), ) +#TODO:Move all this to another file after fijal finishes reorganization +def _to_coords(space, arr, w_item_or_slice): + '''Always returns a list of coords. + ''' + start = stop = step = 0 + if space.isinstance_w(w_item_or_slice, space.w_int): + start = space.int_w(w_item_or_slice) + if start < 0: + start += arr.size + stop = start+1 + step = 1 + elif space.isinstance_w(w_item_or_slice, space.w_slice): + start, stop, step, lngth = space.decode_index4(w_item_or_slice,arr.size) + else: + operationerrfmt(space.NotImplementedError,'cannot iterator over %s yet', w_item_or_slice) + retval = [] + for i in range(start, stop, step): + coords = [] + ii = i + if arr.order =='C': + for s in range(len(arr.shape) -1, -1, -1): + coords.insert(0,ii % arr.shape[s]) + ii /= arr.shape[s] + else: + raise NotImplementedError + #untested code. Erase? + for s in range(len(arr.shape)): + coords.append(ii % arr.shape[s]) + ii /= arr.shape[s] + if ii != 0: + raise OperationError(space.w_IndexError, + space.wrap("invalid index")) + + retval.append(space.newtuple([space.wrap(c) for c in coords])) + return retval class W_FlatIterator(ViewArray): @jit.unroll_safe def __init__(self, arr): arr = arr.get_concrete() - size = 1 - for sh in arr.shape: - size *= sh - if arr.strides[-1] <= arr.strides[0]: - self.strides = [arr.strides[-1]] - self.backstrides = [arr.backstrides[-1]] - else: - XXX - # This will not work: getitem and setitem will - # fail. Need to be smarter: calculate the indices from the int - self.strides = [arr.strides[0]] - self.backstrides = [arr.backstrides[0]] - ViewArray.__init__(self, size, [size], arr.dtype, order=arr.order, - parent=arr) self.shapelen = len(arr.shape) sig = arr.find_sig() - #self.iter = OneDimIterator(arr.start, self.strides[0], - # self.shape[0]) self.iter = sig.create_frame(arr).get_final_iter() - self.start = arr.start self.base = arr + self.index = 0 + ViewArray.__init__(self, arr.size, [arr.size], arr.dtype, arr.order, arr) def descr_next(self, space): if self.iter.done(): raise OperationError(space.w_StopIteration, space.w_None) - result = self.getitem(self.iter.offset) + result = self.base.getitem(self.iter.offset) self.iter = self.iter.next(self.shapelen) + self.index += 1 return result def descr_iter(self): return self def descr_index(self, space): - return space.wrap(self.iter.offset) + return space.wrap(self.index) + def descr_coords(self, space): + return _to_coords(space, self.base, space.wrap(self.index))[0] + + def descr_getitem(self, space, w_idx): + coords = _to_coords(space, self.base, w_idx) + if len(coords)>1: + w_result = W_NDimArray(len(coords), [len(coords)], self.base.dtype, + self.base.order) + for i,c in enumerate(coords): + w_val = self.base.descr_getitem(space, c) + w_result.setitem(i,w_val) + return w_result + else: + return self.base.descr_getitem(space, coords[0]) + + def descr_setitem(self, space, w_idx, w_value): + coords = _to_coords(space, self.base, w_idx) + arr = convert_to_array(space, w_value) + ai = 0 + for c in coords: + v = arr.getitem(ai) + self.base.descr_setitem(space, c,v) + ai = (ai + 1) % arr.size + W_FlatIterator.typedef = TypeDef( 'flatiter', next = interp2app(W_FlatIterator.descr_next), __iter__ = interp2app(W_FlatIterator.descr_iter), - __getitem__ = interp2app(BaseArray.descr_getitem), - __setitem__ = interp2app(BaseArray.descr_setitem), + __getitem__ = interp2app(W_FlatIterator.descr_getitem), + __setitem__ = interp2app(W_FlatIterator.descr_setitem), index = GetSetProperty(W_FlatIterator.descr_index), + coords = GetSetProperty(W_FlatIterator.descr_coords), + ) W_FlatIterator.acceptable_as_base_class = False diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py --- a/pypy/module/micronumpy/test/test_numarray.py +++ b/pypy/module/micronumpy/test/test_numarray.py @@ -1300,7 +1300,7 @@ assert (a + a).flat[3] == 6 assert a[::2].flat[3] == 6 assert a.reshape(2,5).flat[3] == 3 - b = a.flat + b = a.reshape(2,5).flat b.next() b.next() b.next() @@ -1309,6 +1309,19 @@ raises(IndexError, "b[11]") raises(IndexError, "b[-11]") assert b.index == 3 + assert b.coords == (0,3) + + def test_flatiter_setitem(self): + from _numpypy import arange, array + a = arange(12).reshape(3,4) + b = a.T.flat + b[6::2] = [-1, -2] + assert (a == [[0, 1, -1, 3], [4, 5, 6, -1], [8, 9, -2, 11]]).all() + b[1:2] = [[[100]]] + assert(a[0,0] == 100) + assert(a[1,0] == 100) + b[array([10, 11])] == [-20, -40] + def test_flatiter_view(self): from _numpypy import arange @@ -1323,8 +1336,14 @@ def test_flatiter_transpose(self): from _numpypy import arange - a = arange(10) - assert a.reshape(2,5).T.flat[3] == 6 + a = arange(10).reshape(2,5).T + b = a.flat + assert (b[:5] == [0, 5, 1, 6, 2]).all() + b.next() + b.next() + b.next() + assert b.index == 3 + assert b.coords == (1,1) def test_slice_copy(self): from _numpypy import zeros _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit