Author: Christian Tismer <[email protected]>
Branch:
Changeset: r53767:c29aa5160f63
Date: 2012-03-18 01:13 +0100
http://bitbucket.org/pypy/pypy/changeset/c29aa5160f63/
Log: Merge
diff --git a/lib_pypy/numpypy/core/numeric.py b/lib_pypy/numpypy/core/numeric.py
--- a/lib_pypy/numpypy/core/numeric.py
+++ b/lib_pypy/numpypy/core/numeric.py
@@ -6,7 +6,7 @@
import _numpypy as multiarray # ARGH
from numpypy.core.arrayprint import array2string
-
+newaxis = None
def asanyarray(a, dtype=None, order=None, maskna=None, ownmaskna=False):
"""
@@ -319,4 +319,4 @@
False_ = bool_(False)
True_ = bool_(True)
e = math.e
-pi = math.pi
\ No newline at end of file
+pi = math.pi
diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -50,6 +50,7 @@
# structures to describe slicing
class Chunk(object):
+ axis_step = 1
def __init__(self, start, stop, step, lgt):
self.start = start
self.stop = stop
@@ -64,6 +65,16 @@
return 'Chunk(%d, %d, %d, %d)' % (self.start, self.stop, self.step,
self.lgt)
+class NewAxisChunk(Chunk):
+ start = 0
+ stop = 1
+ step = 1
+ lgt = 1
+ axis_step = 0
+
+ def __init__(self):
+ pass
+
class BaseTransform(object):
pass
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -7,10 +7,10 @@
from pypy.module.micronumpy.appbridge import get_appbridge_cache
from pypy.module.micronumpy.dot import multidim_dot, match_dot_shapes
from pypy.module.micronumpy.interp_iter import (ArrayIterator,
- SkipLastAxisIterator, Chunk, ViewIterator)
+ SkipLastAxisIterator, Chunk, NewAxisChunk, ViewIterator)
from pypy.module.micronumpy.strides import (calculate_slice_strides,
shape_agreement, find_shape_and_elems, get_shape_from_iterable,
- calc_new_strides, to_coords)
+ calc_new_strides, to_coords, enumerate_chunks)
from pypy.rlib import jit
from pypy.rlib.rstring import StringBuilder
from pypy.rpython.lltypesystem import lltype, rffi
@@ -321,6 +321,13 @@
is a list of scalars that match the size of shape
"""
shape_len = len(self.shape)
+ if space.isinstance_w(w_idx, space.w_tuple):
+ for w_item in space.fixedview(w_idx):
+ if (space.isinstance_w(w_item, space.w_slice) or
+ space.isinstance_w(w_item, space.w_NoneType)):
+ return False
+ elif space.isinstance_w(w_idx, space.w_NoneType):
+ return False
if shape_len == 0:
raise OperationError(space.w_IndexError, space.wrap(
"0-d arrays can't be indexed"))
@@ -336,20 +343,25 @@
if lgt > shape_len:
raise OperationError(space.w_IndexError,
space.wrap("invalid index"))
- if lgt < shape_len:
- return False
- for w_item in space.fixedview(w_idx):
- if space.isinstance_w(w_item, space.w_slice):
- return False
- return True
+ return lgt == shape_len
@jit.unroll_safe
def _prepare_slice_args(self, space, w_idx):
if (space.isinstance_w(w_idx, space.w_int) or
space.isinstance_w(w_idx, space.w_slice)):
return [Chunk(*space.decode_index4(w_idx, self.shape[0]))]
- return [Chunk(*space.decode_index4(w_item, self.shape[i])) for i, w_item in
- enumerate(space.fixedview(w_idx))]
+ elif space.isinstance_w(w_idx, space.w_NoneType):
+ return [NewAxisChunk()]
+ result = []
+ i = 0
+ for w_item in space.fixedview(w_idx):
+ if space.isinstance_w(w_item, space.w_NoneType):
+ result.append(NewAxisChunk())
+ else:
+ result.append(Chunk(*space.decode_index4(w_item,
+ self.shape[i])))
+ i += 1
+ return result
def count_all_true(self, arr):
sig = arr.find_sig()
@@ -443,7 +455,7 @@
def create_slice(self, chunks):
shape = []
i = -1
- for i, chunk in enumerate(chunks):
+ for i, chunk in enumerate_chunks(chunks):
chunk.extend_shape(shape)
s = i + 1
assert s >= 0
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -1,6 +1,14 @@
from pypy.rlib import jit
from pypy.interpreter.error import OperationError
+def enumerate_chunks(chunks):
+ result = []
+ i = -1
+ for chunk in chunks:
+ i += chunk.axis_step
+ result.append((i, chunk))
+ return result
+
@jit.look_inside_iff(lambda shape, start, strides, backstrides, chunks:
jit.isconstant(len(chunks))
)
@@ -10,7 +18,7 @@
rstart = start
rshape = []
i = -1
- for i, chunk in enumerate(chunks):
+ for i, chunk in enumerate_chunks(chunks):
if chunk.step != 0:
rstrides.append(strides[i] * chunk.step)
rbackstrides.append(strides[i] * (chunk.lgt - 1) * chunk.step)
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -374,6 +374,57 @@
assert a[1] == 0.
assert a[3] == 0.
+ def test_newaxis(self):
+ from _numpypy import array
+ from numpypy.core.numeric import newaxis
+ a = array(range(5))
+ b = array([range(5)])
+ assert (a[newaxis] == b).all()
+
+ def test_newaxis_slice(self):
+ from _numpypy import array
+ from numpypy.core.numeric import newaxis
+
+ a = array(range(5))
+ b = array(range(1,5))
+ c = array([range(1,5)])
+ d = array([[x] for x in range(1,5)])
+
+ assert (a[1:] == b).all()
+ assert (a[1:,newaxis] == d).all()
+ assert (a[newaxis,1:] == c).all()
+
+ def test_newaxis_assign(self):
+ from _numpypy import array
+ from numpypy.core.numeric import newaxis
+
+ a = array(range(5))
+ a[newaxis,1] = [2]
+ assert a[1] == 2
+
+ def test_newaxis_virtual(self):
+ from _numpypy import array
+ from numpypy.core.numeric import newaxis
+
+ a = array(range(5))
+ b = (a + a)[newaxis]
+ c = array([[0, 2, 4, 6, 8]])
+ assert (b == c).all()
+
+ def test_newaxis_then_slice(self):
+ from _numpypy import array
+ from numpypy.core.numeric import newaxis
+ a = array(range(5))
+ b = a[newaxis]
+ assert (b[0,1:] == a[1:]).all()
+
+ def test_slice_then_newaxis(self):
+ from _numpypy import array
+ from numpypy.core.numeric import newaxis
+ a = array(range(5))
+ b = a[2:]
+ assert (b[newaxis] == [[2, 3, 4]]).all()
+
def test_scalar(self):
from _numpypy import array, dtype
a = array(3)
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit