Author: Brian Kearns <[email protected]>
Branch:
Changeset: r67695:be671931b570
Date: 2013-10-29 12:41 -0400
http://bitbucket.org/pypy/pypy/changeset/be671931b570/
Log: test and fix numpy void getitem behavior
diff --git a/pypy/module/micronumpy/interp_boxes.py b/pypy/module/micronumpy/interp_boxes.py
--- a/pypy/module/micronumpy/interp_boxes.py
+++ b/pypy/module/micronumpy/interp_boxes.py
@@ -378,25 +378,27 @@
class W_VoidBox(W_FlexibleBox):
def descr_getitem(self, space, w_item):
- from pypy.module.micronumpy.types import VoidType
- if space.isinstance_w(w_item, space.w_str):
+ if space.isinstance_w(w_item, space.w_basestring):
item = space.str_w(w_item)
elif space.isinstance_w(w_item, space.w_int):
- #Called by iterator protocol
indx = space.int_w(w_item)
try:
item = self.dtype.fieldnames[indx]
except IndexError:
- raise OperationError(space.w_IndexError,
- space.wrap("Iterated over too many fields %d" % indx))
+ if indx < 0:
+ indx += len(self.dtype.fieldnames)
+ raise OperationError(space.w_IndexError, space.wrap(
+ "invalid index (%d)" % indx))
else:
raise OperationError(space.w_IndexError, space.wrap(
- "Can only access fields of record with int or str"))
+ "invalid index"))
try:
ofs, dtype = self.dtype.fields[item]
except KeyError:
- raise OperationError(space.w_IndexError,
- space.wrap("Field %s does not exist" % item))
+ raise OperationError(space.w_IndexError, space.wrap(
+ "invalid index"))
+
+ from pypy.module.micronumpy.types import VoidType
if isinstance(dtype.itemtype, VoidType):
read_val = dtype.itemtype.readarray(self.arr, self.ofs, ofs, dtype)
else:
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -2924,10 +2924,22 @@
d = dtype([("x", "int", 3), ("y", "float", 5)])
a = array([([1, 2, 3], [0.5, 1.5, 2.5, 3.5, 4.5]), ([4, 5, 6], [5.5,
6.5, 7.5, 8.5, 9.5])], dtype=d)
- assert (a[0]["x"] == [1, 2, 3]).all()
- assert (a[0]["y"] == [0.5, 1.5, 2.5, 3.5, 4.5]).all()
- assert (a[1]["x"] == [4, 5, 6]).all()
- assert (a[1]["y"] == [5.5, 6.5, 7.5, 8.5, 9.5]).all()
+ for v in ['x', u'x', 0, -2]:
+ assert (a[0][v] == [1, 2, 3]).all()
+ assert (a[1][v] == [4, 5, 6]).all()
+ for v in ['y', u'y', -1, 1]:
+ assert (a[0][v] == [0.5, 1.5, 2.5, 3.5, 4.5]).all()
+ assert (a[1][v] == [5.5, 6.5, 7.5, 8.5, 9.5]).all()
+ for v in [-3, 2]:
+ exc = raises(IndexError, "a[0][%d]" % v)
+ assert exc.value.message == "invalid index (%d)" % (v + 2 if v < 0
else v)
+ exc = raises(IndexError, "a[0]['z']")
+ assert exc.value.message == "invalid index"
+ exc = raises(IndexError, "a[0][None]")
+ assert exc.value.message == "invalid index"
+
+ exc = raises(IndexError, "a[0][None]")
+ assert exc.value.message == 'invalid index'
a[0]["x"][0] = 200
assert a[0]["x"][0] == 200
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit