Author: Ilya Osadchiy <osadchiy.i...@gmail.com>
Branch: numpy-multidim-exp
Changeset: r45090:b2dc68ec3a1a
Date: 2011-06-21 22:28 +0300
http://bitbucket.org/pypy/pypy/changeset/b2dc68ec3a1a/
Log: numpy: something on multidimensions diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -83,12 +83,31 @@ def descr_len(self, space): return self.get_concrete().descr_len(space) + def subscript_to_index(subscript, shape): + # TODO: is it better to store cumulative multiply of shape and then index = reduce("add", map("mul", subscript, cummult_shape)) ? + index = 0 + stride = 1 + for ind, size in zip(subscript, shape): + index += ind * stride + stride *= size + def descr_getitem(self, space, w_idx): - # TODO: indexing by tuples - start, stop, step, slice_length = space.decode_index4(w_idx, self.find_size()) - if step == 0: - # Single index - return space.wrap(self.get_concrete().getitem(start)) + if space.is_true(space.isinstance(w_idx, space.w_tuple)): + # TODO: slices inside tuples, incomplete ind etc + subscript = space.unpacktuple(w_idx) + shape = self.find_shape() + if len(subscript) == len(shape): + # Fully qualified index + idx = subscript_to_index(subscript, shape) + is_single_elem = True + else: + start, stop, step, slice_length = space.decode_index4(w_idx, self.find_size()) + idx = start + is_single_elem = (step == 0) + + if is_single_elem: + # Single element + return space.wrap(self.get_concrete().getitem(idx)) else: # Slice res = SingleDimSlice(start, stop, step, slice_length, self, self.signature.transition(SingleDimSlice.static_signature)) @@ -110,6 +129,9 @@ BaseArray.__init__(self) self.float_value = float_value + def find_shape(self): + raise ValueError + def find_size(self): raise ValueError @@ -120,6 +142,7 @@ """ Class for representing virtual arrays, such as binary ops or ufuncs """ + _immutable_fields_ = ["shape"] def __init__(self, signature): BaseArray.__init__(self) self.forced_result = None @@ -133,7 +156,11 @@ i = 0 signature = self.signature result_size = self.find_size() - result = 
SingleDimArray(result_size) + result_shape = self.find_shape() + if len(result_shape) == 1: + result = SingleDimArray(result_size) + else: + result = MultiDimArray(result_size) while i < result_size: numpy_driver.jit_merge_point(signature=signature, result_size=result_size, i=i, @@ -156,13 +183,18 @@ return self.forced_result.eval(i) return self._eval(i) + def find_shape(self): + if self.forced_result is not None: + # The result has been computed and sources may be unavailable + return self.forced_result.find_shape() + return self._find_shape() + def find_size(self): if self.forced_result is not None: # The result has been computed and sources may be unavailable return self.forced_result.find_size() return self._find_size() - class Call1(VirtualArray): _immutable_fields_ = ["function", "values"] @@ -174,6 +206,9 @@ def _del_sources(self): self.values = None + def _find_shape(self): + return self.values.find_shape() + def _find_size(self): return self.values.find_size() @@ -195,6 +230,13 @@ self.left = None self.right = None + def _find_shape(self): + try: + return self.left.find_shape() + except ValueError: + pass + return self.right.find_shape() + def _find_size(self): try: return self.left.find_size() @@ -247,6 +289,9 @@ self.step = step self.size = slice_length + def find_shape(self): + return (self.size,) + def find_size(self): return self.size @@ -254,7 +299,10 @@ return (self.start + item * self.step) -class SingleDimArray(BaseArray): +class ConcreteArray(BaseArray): + """ + Class for array arrays that actually store data + """ signature = Signature() def __init__(self, size): @@ -273,6 +321,19 @@ def eval(self, i): return self.storage[i] + def getitem(self, item): + return self.storage[item] + + def __del__(self): + lltype.free(self.storage, flavor='raw') + +class SingleDimArray(ConcreteArray): + def __init__(self, size): + ConcreteArray.__init__(self, size) + + def find_shape(self): + return (self.size,) + def getindex(self, space, item): if item >= 
self.size: raise operationerrfmt(space.w_IndexError, @@ -287,17 +348,28 @@ def descr_len(self, space): return space.wrap(self.size) - def getitem(self, item): - return self.storage[item] - @unwrap_spec(item=int, value=float) def descr_setitem(self, space, item, value): item = self.getindex(space, item) self.invalidated() self.storage[item] = value - def __del__(self): - lltype.free(self.storage, flavor='raw') +class MultiDimArray(ConcreteArray): + _immutable_fields_ = ["shape"] + def __init__(self, size, shape): + ConcreteArray.__init__(self, size) + self.shape = shape + + def find_shape(self): + return self.shape + + def descr_len(self, space): + return space.wrap(self.shape(0)) + + def descr_setitem(self, space, w_subscript, w_value): + item = self.getindex(space, item) + self.invalidated() + self.storage[item] = value def descr_new_numarray(space, w_type, w_size_or_iterable): l = space.listview(w_size_or_iterable) @@ -308,10 +380,16 @@ i += 1 return space.wrap(arr) -@unwrap_spec(ObjSpace, int) -def zeros(space, size): - return space.wrap(SingleDimArray(size)) - +#@unwrap_spec(ObjSpace, int) +def zeros(space, w_size): + if space.is_true(space.isinstance(w_size, space.w_tuple)): + shape = tuple(space.unpackiterable(w_size)) + size = reduce(lambda x, y: x*y, shape) + return space.wrap(MultiDimArray(size, shape)) + elif space.is_true(space.isinstance(w_size, space.w_int)): + return space.wrap(SingleDimArray(space.int_w(w_size))) + else: + raise OperationError(space.w_TypeError, space.wrap("expected sequence object with len >= 0")) BaseArray.typedef = TypeDef( 'numarray', diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py --- a/pypy/module/micronumpy/test/test_numarray.py +++ b/pypy/module/micronumpy/test/test_numarray.py @@ -175,7 +175,6 @@ a[2] = 20 assert s[2] == 20 - def test_slice_invaidate(self): # check that slice shares invalidation list with from numpy import array 
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit