This is an automated email from the ASF dual-hosted git repository.
uwe pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 777f986 ARROW-2331: [Python] Fix indexing for negative or
out-of-bounds indices
777f986 is described below
commit 777f986129594dd8859d1215dbf89ec2495b6904
Author: Antoine Pitrou <[email protected]>
AuthorDate: Fri Mar 23 17:03:29 2018 +0100
ARROW-2331: [Python] Fix indexing for negative or out-of-bounds indices
Author: Antoine Pitrou <[email protected]>
Closes #1770 from pitrou/ARROW-2331-python-indexing and squashes the
following commits:
aec1ef0 <Antoine Pitrou> Try to fix downcast errors
1a38451 <Antoine Pitrou> ARROW-2331: Fix indexing for negative or
out-of-bounds indices
---
python/pyarrow/array.pxi | 19 +++++++++++++++----
python/pyarrow/lib.pxd | 2 ++
python/pyarrow/scalar.pxi | 7 +++++--
python/pyarrow/table.pxi | 19 ++++++++-----------
python/pyarrow/tests/test_array.py | 14 ++++++++++++++
python/pyarrow/tests/test_scalars.py | 6 ++++++
python/pyarrow/tests/test_table.py | 6 ++++++
python/pyarrow/types.pxi | 30 +++++++-----------------------
8 files changed, 63 insertions(+), 40 deletions(-)
diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi
index afb68a2..c40d7b5 100644
--- a/python/pyarrow/array.pxi
+++ b/python/pyarrow/array.pxi
@@ -232,6 +232,17 @@ def _normalize_slice(object arrow_obj, slice key):
return arrow_obj.slice(start, stop - start)
+cdef Py_ssize_t _normalize_index(Py_ssize_t index,
+ Py_ssize_t length) except -1:
+ if index < 0:
+ index += length
+ if index < 0:
+ raise IndexError("index out of bounds")
+ elif index >= length:
+ raise IndexError("index out of bounds")
+ return index
+
+
cdef class _FunctionContext:
cdef:
unique_ptr[CFunctionContext] ctx
@@ -427,6 +438,9 @@ cdef class Array:
return self.ap.Equals(deref(other.ap))
def __len__(self):
+ return self.length()
+
+ cdef int64_t length(self):
if self.sp_array.get():
return self.sp_array.get().length()
else:
@@ -441,10 +455,7 @@ cdef class Array:
if PySlice_Check(key):
return _normalize_slice(self, key)
- while key < 0:
- key += len(self)
-
- return self.getitem(key)
+ return self.getitem(_normalize_index(key, self.length()))
cdef getitem(self, int64_t i):
return box_scalar(self.type, self.sp_array, i)
diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd
index be103b3..7a1b221 100644
--- a/python/pyarrow/lib.pxd
+++ b/python/pyarrow/lib.pxd
@@ -137,6 +137,7 @@ cdef class ListValue(ArrayValue):
CListArray* ap
cdef getitem(self, int64_t i)
+ cdef int64_t length(self)
cdef class UnionValue(ArrayValue):
@@ -164,6 +165,7 @@ cdef class Array:
cdef void init(self, const shared_ptr[CArray]& sp_array)
cdef getitem(self, int64_t i)
+ cdef int64_t length(self)
cdef class Tensor:
diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi
index a0f8480..e9c2c7c 100644
--- a/python/pyarrow/scalar.pxi
+++ b/python/pyarrow/scalar.pxi
@@ -294,10 +294,10 @@ cdef class BinaryValue(ArrayValue):
cdef class ListValue(ArrayValue):
def __len__(self):
- return self.ap.value_length(self.index)
+ return self.length()
def __getitem__(self, i):
- return self.getitem(i)
+ return self.getitem(_normalize_index(i, self.length()))
def __iter__(self):
for i in range(len(self)):
@@ -313,6 +313,9 @@ cdef class ListValue(ArrayValue):
cdef int64_t j = self.ap.value_offset(self.index) + i
return box_scalar(self.value_type, self.ap.values(), j)
+ cdef int64_t length(self):
+ return self.ap.value_length(self.index)
+
def as_py(self):
cdef:
int64_t j
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 6cfa987..672b9fb 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -84,14 +84,12 @@ cdef class ChunkedArray:
if isinstance(key, slice):
return _normalize_slice(self, key)
elif isinstance(key, six.integer_types):
- item = key
- if item >= self.chunked_array.length() or item < 0:
- return IndexError("ChunkedArray selection out of bounds")
+ index = _normalize_index(key, self.chunked_array.length())
for i in range(self.num_chunks):
- if item < self.chunked_array.chunk(i).get().length():
- return self.chunk(i)[item]
+ if index < self.chunked_array.chunk(i).get().length():
+ return self.chunk(i)[index]
else:
- item -= self.chunked_array.chunk(i).get().length()
+ index -= self.chunked_array.chunk(i).get().length()
else:
raise TypeError("key must either be a slice or integer")
@@ -630,12 +628,10 @@ cdef class RecordBatch:
return pyarrow_wrap_array(self.batch.column(i))
def __getitem__(self, key):
- cdef:
- Py_ssize_t start, stop
if isinstance(key, slice):
return _normalize_slice(self, key)
else:
- return self.column(key)
+ return self.column(_normalize_index(key, self.num_columns))
def serialize(self, memory_pool=None):
"""
@@ -1183,8 +1179,9 @@ cdef class Table:
column.init(self.table.column(index))
return column
- def __getitem__(self, int64_t i):
- return self.column(i)
+ def __getitem__(self, key):
+ cdef int index = <int> _normalize_index(key, self.num_columns)
+ return self.column(index)
def itercolumns(self):
"""
diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py
index d126db3..4a441fb 100644
--- a/python/pyarrow/tests/test_array.py
+++ b/python/pyarrow/tests/test_array.py
@@ -110,6 +110,20 @@ def test_to_pandas_zero_copy():
np_arr.sum()
+def test_array_getitem():
+ arr = pa.array(range(10, 15))
+ lst = arr.to_pylist()
+
+ for idx in range(-len(arr), len(arr)):
+ assert arr[idx].as_py() == lst[idx]
+ for idx in range(-2 * len(arr), -len(arr)):
+ with pytest.raises(IndexError):
+ arr[idx]
+ for idx in range(len(arr), 2 * len(arr)):
+ with pytest.raises(IndexError):
+ arr[idx]
+
+
def test_array_slice():
arr = pa.array(range(10))
diff --git a/python/pyarrow/tests/test_scalars.py b/python/pyarrow/tests/test_scalars.py
index 7061a0d..c63be02 100644
--- a/python/pyarrow/tests/test_scalars.py
+++ b/python/pyarrow/tests/test_scalars.py
@@ -131,6 +131,12 @@ class TestScalars(unittest.TestCase):
assert v.as_py() == ['foo', None]
assert v[0].as_py() == 'foo'
assert v[1] is pa.NA
+ assert v[-1] == v[1]
+ assert v[-2] == v[0]
+ with pytest.raises(IndexError):
+ v[-3]
+ with pytest.raises(IndexError):
+ v[2]
assert arr[1] is pa.NA
diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py
index 356ecb7..8156435 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -31,6 +31,12 @@ def test_chunked_array_getitem():
]
data = pa.chunked_array(data)
assert data[1].as_py() == 2
+ assert data[-1].as_py() == 6
+ assert data[-6].as_py() == 1
+ with pytest.raises(IndexError):
+ data[6]
+ with pytest.raises(IndexError):
+ data[-7]
data_slice = data[2:4]
assert data_slice.to_pylist() == [3, 4]
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index a4391c7..b0557eb 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -166,10 +166,8 @@ cdef class StructType(DataType):
DataType.init(self, type)
def __getitem__(self, i):
- if i < 0 or i >= self.num_children:
- raise IndexError(i)
-
- return pyarrow_wrap_field(self.type.child(i))
+ cdef int index = <int> _normalize_index(i, self.num_children)
+ return pyarrow_wrap_field(self.type.child(index))
property num_children:
@@ -207,7 +205,8 @@ cdef class UnionType(DataType):
assert 0
def __getitem__(self, i):
- return pyarrow_wrap_field(self.type.child(i))
+ cdef int index = <int> _normalize_index(i, self.num_children)
+ return pyarrow_wrap_field(self.type.child(index))
def __getstate__(self):
children = [self[i] for i in range(self.num_children)]
@@ -440,24 +439,9 @@ cdef class Schema:
def __len__(self):
return self.schema.num_fields()
- def __getitem__(self, int i):
- cdef:
- Field result = Field()
- int num_fields = self.schema.num_fields()
- int index
-
- if not -num_fields <= i < num_fields:
- raise IndexError(
- 'Schema field index {:d} is out of range'.format(i)
- )
-
- index = i if i >= 0 else num_fields + i
- assert index >= 0
-
- result.init(self.schema.field(index))
- result.type = pyarrow_wrap_data_type(result.field.type())
-
- return result
+ def __getitem__(self, key):
+ cdef int index = <int> _normalize_index(key, self.schema.num_fields())
+ return pyarrow_wrap_field(self.schema.field(index))
def __iter__(self):
for i in range(len(self)):
--
To stop receiving notification emails like this one, please contact
[email protected].