Author: Ilya Osadchiy <osadchiy.i...@gmail.com> Branch: numpy-indexing-by-arrays Changeset: r47114:7e7b4f1c2c5c Date: 2011-09-05 22:24 +0300 http://bitbucket.org/pypy/pypy/changeset/7e7b4f1c2c5c/
Log: Initial implementation diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -217,7 +217,6 @@ return space.wrap("[" + " ".join(concrete._getnums(True)) + "]") def descr_getitem(self, space, w_idx): - # TODO: indexing by arrays and lists if space.isinstance_w(w_idx, space.w_tuple): length = space.len_w(w_idx) if length == 0: @@ -226,6 +225,24 @@ raise OperationError(space.w_IndexError, space.wrap("invalid index")) w_idx = space.getitem(w_idx, space.wrap(0)) + elif space.issequence_w(w_idx): + w_idx = convert_to_array(space, w_idx) + bool_dtype = space.fromcache(interp_dtype.W_BoolDtype) + int_dtype = space.fromcache(interp_dtype.W_Int64Dtype) + if w_idx.find_dtype() is bool_dtype: + # TODO: indexing by bool array + raise NotImplementedError("sorry, not yet implemented") + else: + # Indexing by array + + # FIXME: should raise exception if any index in + # array is out od bound, but this kills lazy execution + new_sig = signature.Signature.find_sig([ + IndexedByArray.signature, self.signature + ]) + res = IndexedByArray(new_sig, int_dtype, self, w_idx) + return space.wrap(res) + start, stop, step, slice_length = space.decode_index4(w_idx, self.find_size()) if step == 0: # Single index @@ -430,6 +447,29 @@ assert isinstance(call_sig, signature.Call2) return call_sig.func(self.res_dtype, lhs, rhs) +class IndexedByArray(VirtualArray): + """ + Intermediate class for performing indexing of array by another array + """ + signature = signature.BaseSignature() + def __init__(self, signature, int_dtype, source, index): + VirtualArray.__init__(self, signature, source.find_dtype()) + self.source = source + self.index = index + self.int_dtype = int_dtype + + def _del_sources(self): + self.source = None + self.index = None + + def _find_size(self): + return self.index.find_size() + + def _eval(self, i): + idx = self.int_dtype.unbox(self.index.eval(i).convert_to(self.int_dtype)) + val = self.source.eval(idx).convert_to(self.res_dtype) + return val + class ViewArray(BaseArray): """ Class for representing views of arrays, they will reflect changes of parent diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py --- a/pypy/module/micronumpy/interp_ufuncs.py +++ b/pypy/module/micronumpy/interp_ufuncs.py @@ -200,7 +200,7 @@ int64_dtype = space.fromcache(interp_dtype.W_Int64Dtype) if space.is_w(w_type, space.w_bool): - if current_guess is None: + if current_guess is None or current_guess is bool_dtype: return bool_dtype elif space.is_w(w_type, space.w_int): if (current_guess is None or current_guess is bool_dtype or @@ -270,4 +270,4 @@ setattr(self, ufunc_name, ufunc) def get(space): - return space.fromcache(UfuncState) \ No newline at end of file + return space.fromcache(UfuncState) diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py --- a/pypy/module/micronumpy/test/test_numarray.py +++ b/pypy/module/micronumpy/test/test_numarray.py @@ -119,6 +119,20 @@ for i in xrange(5): assert a[i] == b[i] + def test_index_by_array(self): + from numpy import array + a = array(range(5)) + idx_list = [3, 1, 3, 2, 0, 4] + idx_arr = array(idx_list) + a_by_arr = a[idx_arr] + assert len(a_by_arr) == 6 + for i in xrange(6): + assert a_by_arr[i] == range(5)[idx_list[i]] + a_by_list = a[idx_list] + assert len(a_by_list) == 6 + for i in xrange(6): + assert a_by_list[i] == range(5)[idx_list[i]] + def test_setitem(self): from numpy import array a = array(range(5)) _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit