Author: Christian Tismer <tis...@stackless.com> Branch: Changeset: r53767:c29aa5160f63 Date: 2012-03-18 01:13 +0100 http://bitbucket.org/pypy/pypy/changeset/c29aa5160f63/
Log: Merge diff --git a/lib_pypy/numpypy/core/numeric.py b/lib_pypy/numpypy/core/numeric.py --- a/lib_pypy/numpypy/core/numeric.py +++ b/lib_pypy/numpypy/core/numeric.py @@ -6,7 +6,7 @@ import _numpypy as multiarray # ARGH from numpypy.core.arrayprint import array2string - +newaxis = None def asanyarray(a, dtype=None, order=None, maskna=None, ownmaskna=False): """ @@ -319,4 +319,4 @@ False_ = bool_(False) True_ = bool_(True) e = math.e -pi = math.pi \ No newline at end of file +pi = math.pi diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py --- a/pypy/module/micronumpy/interp_iter.py +++ b/pypy/module/micronumpy/interp_iter.py @@ -50,6 +50,7 @@ # structures to describe slicing class Chunk(object): + axis_step = 1 def __init__(self, start, stop, step, lgt): self.start = start self.stop = stop @@ -64,6 +65,16 @@ return 'Chunk(%d, %d, %d, %d)' % (self.start, self.stop, self.step, self.lgt) +class NewAxisChunk(Chunk): + start = 0 + stop = 1 + step = 1 + lgt = 1 + axis_step = 0 + + def __init__(self): + pass + class BaseTransform(object): pass diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -7,10 +7,10 @@ from pypy.module.micronumpy.appbridge import get_appbridge_cache from pypy.module.micronumpy.dot import multidim_dot, match_dot_shapes from pypy.module.micronumpy.interp_iter import (ArrayIterator, - SkipLastAxisIterator, Chunk, ViewIterator) + SkipLastAxisIterator, Chunk, NewAxisChunk, ViewIterator) from pypy.module.micronumpy.strides import (calculate_slice_strides, shape_agreement, find_shape_and_elems, get_shape_from_iterable, - calc_new_strides, to_coords) + calc_new_strides, to_coords, enumerate_chunks) from pypy.rlib import jit from pypy.rlib.rstring import StringBuilder from pypy.rpython.lltypesystem import lltype, rffi @@ -321,6 +321,13 @@ is a list of scalars that match the size of shape """ shape_len = len(self.shape) + if space.isinstance_w(w_idx, space.w_tuple): + for w_item in space.fixedview(w_idx): + if (space.isinstance_w(w_item, space.w_slice) or + space.isinstance_w(w_item, space.w_NoneType)): + return False + elif space.isinstance_w(w_idx, space.w_NoneType): + return False if shape_len == 0: raise OperationError(space.w_IndexError, space.wrap( "0-d arrays can't be indexed")) @@ -336,20 +343,25 @@ if lgt > shape_len: raise OperationError(space.w_IndexError, space.wrap("invalid index")) - if lgt < shape_len: - return False - for w_item in space.fixedview(w_idx): - if space.isinstance_w(w_item, space.w_slice): - return False - return True + return lgt == shape_len @jit.unroll_safe def _prepare_slice_args(self, space, w_idx): if (space.isinstance_w(w_idx, space.w_int) or space.isinstance_w(w_idx, space.w_slice)): return [Chunk(*space.decode_index4(w_idx, self.shape[0]))] - return [Chunk(*space.decode_index4(w_item, self.shape[i])) for i, w_item in - enumerate(space.fixedview(w_idx))] + elif space.isinstance_w(w_idx, space.w_NoneType): + return [NewAxisChunk()] + result = [] + i = 0 + for w_item in space.fixedview(w_idx): + if space.isinstance_w(w_item, space.w_NoneType): + result.append(NewAxisChunk()) + else: + result.append(Chunk(*space.decode_index4(w_item, + self.shape[i]))) + i += 1 + return result def count_all_true(self, arr): sig = arr.find_sig() @@ -443,7 +455,7 @@ def create_slice(self, chunks): shape = [] i = -1 - for i, chunk in enumerate(chunks): + for i, chunk in enumerate_chunks(chunks): chunk.extend_shape(shape) s = i + 1 assert s >= 0 diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py --- a/pypy/module/micronumpy/strides.py +++ b/pypy/module/micronumpy/strides.py @@ -1,6 +1,14 @@ from pypy.rlib import jit from pypy.interpreter.error import OperationError +def enumerate_chunks(chunks): + result = [] + i = -1 + for chunk in chunks: + i += chunk.axis_step + result.append((i, chunk)) + return result + @jit.look_inside_iff(lambda shape, start, strides, backstrides, chunks: jit.isconstant(len(chunks)) ) @@ -10,7 +18,7 @@ rstart = start rshape = [] i = -1 - for i, chunk in enumerate(chunks): + for i, chunk in enumerate_chunks(chunks): if chunk.step != 0: rstrides.append(strides[i] * chunk.step) rbackstrides.append(strides[i] * (chunk.lgt - 1) * chunk.step) diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py --- a/pypy/module/micronumpy/test/test_numarray.py +++ b/pypy/module/micronumpy/test/test_numarray.py @@ -374,6 +374,57 @@ assert a[1] == 0. assert a[3] == 0. + def test_newaxis(self): + from _numpypy import array + from numpypy.core.numeric import newaxis + a = array(range(5)) + b = array([range(5)]) + assert (a[newaxis] == b).all() + + def test_newaxis_slice(self): + from _numpypy import array + from numpypy.core.numeric import newaxis + + a = array(range(5)) + b = array(range(1,5)) + c = array([range(1,5)]) + d = array([[x] for x in range(1,5)]) + + assert (a[1:] == b).all() + assert (a[1:,newaxis] == d).all() + assert (a[newaxis,1:] == c).all() + + def test_newaxis_assign(self): + from _numpypy import array + from numpypy.core.numeric import newaxis + + a = array(range(5)) + a[newaxis,1] = [2] + assert a[1] == 2 + + def test_newaxis_virtual(self): + from _numpypy import array + from numpypy.core.numeric import newaxis + + a = array(range(5)) + b = (a + a)[newaxis] + c = array([[0, 2, 4, 6, 8]]) + assert (b == c).all() + + def test_newaxis_then_slice(self): + from _numpypy import array + from numpypy.core.numeric import newaxis + a = array(range(5)) + b = a[newaxis] + assert (b[0,1:] == a[1:]).all() + + def test_slice_then_newaxis(self): + from _numpypy import array + from numpypy.core.numeric import newaxis + a = array(range(5)) + b = a[2:] + assert (b[newaxis] == [[2, 3, 4]]).all() + def test_scalar(self): from _numpypy import array, dtype a = array(3) _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit