Author: Brian Kearns <bdkea...@gmail.com> Branch: Changeset: r67695:be671931b570 Date: 2013-10-29 12:41 -0400 http://bitbucket.org/pypy/pypy/changeset/be671931b570/
Log: test and fix numpy void getitem behavior diff --git a/pypy/module/micronumpy/interp_boxes.py b/pypy/module/micronumpy/interp_boxes.py --- a/pypy/module/micronumpy/interp_boxes.py +++ b/pypy/module/micronumpy/interp_boxes.py @@ -378,25 +378,27 @@ class W_VoidBox(W_FlexibleBox): def descr_getitem(self, space, w_item): - from pypy.module.micronumpy.types import VoidType - if space.isinstance_w(w_item, space.w_str): + if space.isinstance_w(w_item, space.w_basestring): item = space.str_w(w_item) elif space.isinstance_w(w_item, space.w_int): - #Called by iterator protocol indx = space.int_w(w_item) try: item = self.dtype.fieldnames[indx] except IndexError: - raise OperationError(space.w_IndexError, - space.wrap("Iterated over too many fields %d" % indx)) + if indx < 0: + indx += len(self.dtype.fieldnames) + raise OperationError(space.w_IndexError, space.wrap( + "invalid index (%d)" % indx)) else: raise OperationError(space.w_IndexError, space.wrap( - "Can only access fields of record with int or str")) + "invalid index")) try: ofs, dtype = self.dtype.fields[item] except KeyError: - raise OperationError(space.w_IndexError, - space.wrap("Field %s does not exist" % item)) + raise OperationError(space.w_IndexError, space.wrap( + "invalid index")) + + from pypy.module.micronumpy.types import VoidType if isinstance(dtype.itemtype, VoidType): read_val = dtype.itemtype.readarray(self.arr, self.ofs, ofs, dtype) else: diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py --- a/pypy/module/micronumpy/test/test_numarray.py +++ b/pypy/module/micronumpy/test/test_numarray.py @@ -2924,10 +2924,22 @@ d = dtype([("x", "int", 3), ("y", "float", 5)]) a = array([([1, 2, 3], [0.5, 1.5, 2.5, 3.5, 4.5]), ([4, 5, 6], [5.5, 6.5, 7.5, 8.5, 9.5])], dtype=d) - assert (a[0]["x"] == [1, 2, 3]).all() - assert (a[0]["y"] == [0.5, 1.5, 2.5, 3.5, 4.5]).all() - assert (a[1]["x"] == [4, 5, 6]).all() - assert (a[1]["y"] == [5.5, 6.5, 7.5, 8.5, 9.5]).all() + for v in ['x', u'x', 0, -2]: + assert (a[0][v] == [1, 2, 3]).all() + assert (a[1][v] == [4, 5, 6]).all() + for v in ['y', u'y', -1, 1]: + assert (a[0][v] == [0.5, 1.5, 2.5, 3.5, 4.5]).all() + assert (a[1][v] == [5.5, 6.5, 7.5, 8.5, 9.5]).all() + for v in [-3, 2]: + exc = raises(IndexError, "a[0][%d]" % v) + assert exc.value.message == "invalid index (%d)" % (v + 2 if v < 0 else v) + exc = raises(IndexError, "a[0]['z']") + assert exc.value.message == "invalid index" + exc = raises(IndexError, "a[0][None]") + assert exc.value.message == "invalid index" + + exc = raises(IndexError, "a[0][None]") + assert exc.value.message == 'invalid index' a[0]["x"][0] = 200 assert a[0]["x"][0] == 200 _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit