[
https://issues.apache.org/jira/browse/ARROW-2331?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16411623#comment-16411623
]
ASF GitHub Bot commented on ARROW-2331:
---------------------------------------
xhochy closed pull request #1770: ARROW-2331: [Python] Fix indexing for
negative or out-of-bounds indices
URL: https://github.com/apache/arrow/pull/1770
This is a PR merged from a forked repository.
As GitHub hides the original diff of a foreign pull request (one made
from a fork) once it is merged, the diff is reproduced below for the
sake of provenance:
diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi
index afb68a2fb..c40d7b554 100644
--- a/python/pyarrow/array.pxi
+++ b/python/pyarrow/array.pxi
@@ -232,6 +232,17 @@ def _normalize_slice(object arrow_obj, slice key):
return arrow_obj.slice(start, stop - start)
+cdef Py_ssize_t _normalize_index(Py_ssize_t index,
+ Py_ssize_t length) except -1:
+ if index < 0:
+ index += length
+ if index < 0:
+ raise IndexError("index out of bounds")
+ elif index >= length:
+ raise IndexError("index out of bounds")
+ return index
+
+
cdef class _FunctionContext:
cdef:
unique_ptr[CFunctionContext] ctx
@@ -427,6 +438,9 @@ cdef class Array:
return self.ap.Equals(deref(other.ap))
def __len__(self):
+ return self.length()
+
+ cdef int64_t length(self):
if self.sp_array.get():
return self.sp_array.get().length()
else:
@@ -441,10 +455,7 @@ cdef class Array:
if PySlice_Check(key):
return _normalize_slice(self, key)
- while key < 0:
- key += len(self)
-
- return self.getitem(key)
+ return self.getitem(_normalize_index(key, self.length()))
cdef getitem(self, int64_t i):
return box_scalar(self.type, self.sp_array, i)
diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd
index be103b354..7a1b221db 100644
--- a/python/pyarrow/lib.pxd
+++ b/python/pyarrow/lib.pxd
@@ -137,6 +137,7 @@ cdef class ListValue(ArrayValue):
CListArray* ap
cdef getitem(self, int64_t i)
+ cdef int64_t length(self)
cdef class UnionValue(ArrayValue):
@@ -164,6 +165,7 @@ cdef class Array:
cdef void init(self, const shared_ptr[CArray]& sp_array)
cdef getitem(self, int64_t i)
+ cdef int64_t length(self)
cdef class Tensor:
diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi
index a0f8480db..e9c2c7c84 100644
--- a/python/pyarrow/scalar.pxi
+++ b/python/pyarrow/scalar.pxi
@@ -294,10 +294,10 @@ cdef class BinaryValue(ArrayValue):
cdef class ListValue(ArrayValue):
def __len__(self):
- return self.ap.value_length(self.index)
+ return self.length()
def __getitem__(self, i):
- return self.getitem(i)
+ return self.getitem(_normalize_index(i, self.length()))
def __iter__(self):
for i in range(len(self)):
@@ -313,6 +313,9 @@ cdef class ListValue(ArrayValue):
cdef int64_t j = self.ap.value_offset(self.index) + i
return box_scalar(self.value_type, self.ap.values(), j)
+ cdef int64_t length(self):
+ return self.ap.value_length(self.index)
+
def as_py(self):
cdef:
int64_t j
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 6cfa9873b..672b9fb7d 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -84,14 +84,12 @@ cdef class ChunkedArray:
if isinstance(key, slice):
return _normalize_slice(self, key)
elif isinstance(key, six.integer_types):
- item = key
- if item >= self.chunked_array.length() or item < 0:
- return IndexError("ChunkedArray selection out of bounds")
+ index = _normalize_index(key, self.chunked_array.length())
for i in range(self.num_chunks):
- if item < self.chunked_array.chunk(i).get().length():
- return self.chunk(i)[item]
+ if index < self.chunked_array.chunk(i).get().length():
+ return self.chunk(i)[index]
else:
- item -= self.chunked_array.chunk(i).get().length()
+ index -= self.chunked_array.chunk(i).get().length()
else:
raise TypeError("key must either be a slice or integer")
@@ -630,12 +628,10 @@ cdef class RecordBatch:
return pyarrow_wrap_array(self.batch.column(i))
def __getitem__(self, key):
- cdef:
- Py_ssize_t start, stop
if isinstance(key, slice):
return _normalize_slice(self, key)
else:
- return self.column(key)
+ return self.column(_normalize_index(key, self.num_columns))
def serialize(self, memory_pool=None):
"""
@@ -1183,8 +1179,9 @@ cdef class Table:
column.init(self.table.column(index))
return column
- def __getitem__(self, int64_t i):
- return self.column(i)
+ def __getitem__(self, key):
+ cdef int index = <int> _normalize_index(key, self.num_columns)
+ return self.column(index)
def itercolumns(self):
"""
diff --git a/python/pyarrow/tests/test_array.py
b/python/pyarrow/tests/test_array.py
index d126db373..4a441fb97 100644
--- a/python/pyarrow/tests/test_array.py
+++ b/python/pyarrow/tests/test_array.py
@@ -110,6 +110,20 @@ def test_to_pandas_zero_copy():
np_arr.sum()
+def test_array_getitem():
+ arr = pa.array(range(10, 15))
+ lst = arr.to_pylist()
+
+ for idx in range(-len(arr), len(arr)):
+ assert arr[idx].as_py() == lst[idx]
+ for idx in range(-2 * len(arr), -len(arr)):
+ with pytest.raises(IndexError):
+ arr[idx]
+ for idx in range(len(arr), 2 * len(arr)):
+ with pytest.raises(IndexError):
+ arr[idx]
+
+
def test_array_slice():
arr = pa.array(range(10))
diff --git a/python/pyarrow/tests/test_scalars.py
b/python/pyarrow/tests/test_scalars.py
index 7061a0d3a..c63be0203 100644
--- a/python/pyarrow/tests/test_scalars.py
+++ b/python/pyarrow/tests/test_scalars.py
@@ -131,6 +131,12 @@ def test_list(self):
assert v.as_py() == ['foo', None]
assert v[0].as_py() == 'foo'
assert v[1] is pa.NA
+ assert v[-1] == v[1]
+ assert v[-2] == v[0]
+ with pytest.raises(IndexError):
+ v[-3]
+ with pytest.raises(IndexError):
+ v[2]
assert arr[1] is pa.NA
diff --git a/python/pyarrow/tests/test_table.py
b/python/pyarrow/tests/test_table.py
index 356ecb7e0..81564352b 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -31,6 +31,12 @@ def test_chunked_array_getitem():
]
data = pa.chunked_array(data)
assert data[1].as_py() == 2
+ assert data[-1].as_py() == 6
+ assert data[-6].as_py() == 1
+ with pytest.raises(IndexError):
+ data[6]
+ with pytest.raises(IndexError):
+ data[-7]
data_slice = data[2:4]
assert data_slice.to_pylist() == [3, 4]
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index a4391c7f9..b0557eb57 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -166,10 +166,8 @@ cdef class StructType(DataType):
DataType.init(self, type)
def __getitem__(self, i):
- if i < 0 or i >= self.num_children:
- raise IndexError(i)
-
- return pyarrow_wrap_field(self.type.child(i))
+ cdef int index = <int> _normalize_index(i, self.num_children)
+ return pyarrow_wrap_field(self.type.child(index))
property num_children:
@@ -207,7 +205,8 @@ cdef class UnionType(DataType):
assert 0
def __getitem__(self, i):
- return pyarrow_wrap_field(self.type.child(i))
+ cdef int index = <int> _normalize_index(i, self.num_children)
+ return pyarrow_wrap_field(self.type.child(index))
def __getstate__(self):
children = [self[i] for i in range(self.num_children)]
@@ -440,24 +439,9 @@ cdef class Schema:
def __len__(self):
return self.schema.num_fields()
- def __getitem__(self, int i):
- cdef:
- Field result = Field()
- int num_fields = self.schema.num_fields()
- int index
-
- if not -num_fields <= i < num_fields:
- raise IndexError(
- 'Schema field index {:d} is out of range'.format(i)
- )
-
- index = i if i >= 0 else num_fields + i
- assert index >= 0
-
- result.init(self.schema.field(index))
- result.type = pyarrow_wrap_data_type(result.field.type())
-
- return result
+ def __getitem__(self, key):
+ cdef int index = <int> _normalize_index(key, self.schema.num_fields())
+ return pyarrow_wrap_field(self.schema.field(index))
def __iter__(self):
for i in range(len(self)):
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
> [Python] Fix indexing implementations
> -------------------------------------
>
> Key: ARROW-2331
> URL: https://issues.apache.org/jira/browse/ARROW-2331
> Project: Apache Arrow
> Issue Type: Bug
> Components: Python
> Affects Versions: 0.9.0
> Reporter: Antoine Pitrou
> Assignee: Antoine Pitrou
> Priority: Minor
> Labels: pull-request-available
>
> A number of {{\_\_getitem\_\_}} implementations handle negative or
> out-of-bounds indices improperly, for example:
> {code:python}
> >>> a = pa.array([11,12,13])
> >>> a[-6]
> 11
> >>> a[-15]
> 11
> >>> a[4]
> NA
> >>> a[3]
> NA
> >>> a[1111]
> NA
> {code}
--
This message was sent by Atlassian JIRA
(v7.6.3#76005)