Author: Maciej Fijalkowski <fij...@gmail.com> Branch: numpy-indexing-by-arrays-2 Changeset: r51501:c5926e6a02bc Date: 2012-01-19 21:31 +0200 http://bitbucket.org/pypy/pypy/changeset/c5926e6a02bc/
Log: progress; diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py --- a/pypy/module/micronumpy/interp_iter.py +++ b/pypy/module/micronumpy/interp_iter.py @@ -24,6 +24,9 @@ def __init__(self, arr): self.arr = arr.get_concrete() + def extend_shape(self, shape): + shape.extend(self.arr.shape) + class BoolArrayChunk(BaseChunk): def __init__(self, arr): self.arr = arr.get_concrete() diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -10,7 +10,7 @@ from pypy.tool.sourcetools import func_with_new_name from pypy.rlib.rstring import StringBuilder from pypy.module.micronumpy.interp_iter import ArrayIterator, OneDimIterator,\ - SkipLastAxisIterator, Chunk, ViewIterator + SkipLastAxisIterator, Chunk, ViewIterator, BoolArrayChunk, IntArrayChunk numpy_driver = jit.JitDriver( greens=['shapelen', 'sig'], @@ -229,6 +229,17 @@ n_old_elems_to_use *= old_shape[oldI] return new_strides +def wrap_chunk(space, w_idx, size): + if (space.isinstance_w(w_idx, space.w_int) or + space.isinstance_w(w_idx, space.w_slice)): + return Chunk(*space.decode_index4(w_idx, size)) + arr = convert_to_array(space, w_idx) + if arr.find_dtype().is_bool_type(): + return BoolArrayChunk(arr) + elif arr.find_dtype().is_int_type(): + return IntArrayChunk(arr) + raise OperationError(space.w_IndexError, space.wrap("arrays used as indices must be of integer (or boolean) type")) + class BaseArray(Wrappable): _attrs_ = ["invalidates", "shape", 'size'] @@ -485,6 +496,8 @@ elif (space.isinstance_w(w_idx, space.w_slice) or space.isinstance_w(w_idx, space.w_int)): return False + if isinstance(w_idx, BaseArray): + return False lgt = space.len_w(w_idx) if lgt > shape_len: raise OperationError(space.w_IndexError, @@ -494,14 +507,15 @@ for w_item in space.fixedview(w_idx): if space.isinstance_w(w_item, space.w_slice): return False + if isinstance(w_item, BaseArray): + return False return True @jit.unroll_safe def _prepare_slice_args(self, space, w_idx): - if (space.isinstance_w(w_idx, space.w_int) or - space.isinstance_w(w_idx, space.w_slice)): - return [Chunk(*space.decode_index4(w_idx, self.shape[0]))] - return [Chunk(*space.decode_index4(w_item, self.shape[i])) for i, w_item in + if not space.isinstance_w(w_idx, space.w_tuple): + return [wrap_chunk(space, w_idx, self.shape[0])] + return [wrap_chunk(space, w_item, self.shape[i]) for i, w_item in enumerate(space.fixedview(w_idx))] def count_all_true(self, arr): @@ -563,6 +577,7 @@ if (isinstance(w_idx, BaseArray) and w_idx.shape == self.shape and w_idx.find_dtype().is_bool_type()): return self.getitem_filter(space, w_idx) + # XXX deal with a scalar if self._single_item_result(space, w_idx): concrete = self.get_concrete() item = concrete._index_of_single_item(space, w_idx) @@ -589,6 +604,13 @@ view = self.create_slice(chunks).get_concrete() view.setslice(space, w_value) + def force_slice(self, shape, chunks): + size = 1 + for elem in shape: + size *= elem + res = W_NDimArray(size, shape, self.find_dtype()) + xxx + @jit.unroll_safe def create_slice(self, chunks): shape = [] @@ -598,6 +620,9 @@ s = i + 1 assert s >= 0 shape += self.shape[s:] + for chunk in chunks: + if not isinstance(chunk, Chunk): + return self.force_slice(shape, chunks) if not isinstance(self, ConcreteArray): return VirtualSlice(self, chunks, shape) r = calculate_slice_strides(self.shape, self.start, self.strides, diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py --- a/pypy/module/micronumpy/test/test_numarray.py +++ b/pypy/module/micronumpy/test/test_numarray.py @@ -1301,16 +1301,13 @@ raises(TypeError, getattr, array(3), '__array_interface__') def test_array_indexing_one_elem(self): - skip("not yet") from _numpypy import array, arange raises(IndexError, 'arange(3)[array([3.5])]') a = arange(3)[array([1])] - assert a == 1 - assert a[0] == 1 + assert a == [1] raises(IndexError,'arange(3)[array([15])]') assert arange(3)[array([-3])] == 0 raises(IndexError,'arange(3)[array([-15])]') - assert arange(3)[array(1)] == 1 def test_array_indexing_bool(self): from _numpypy import arange _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit