This is an automated email from the ASF dual-hosted git repository.

uwe pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 777f986  ARROW-2331: [Python] Fix indexing for negative or out-of-bounds indices
777f986 is described below

commit 777f986129594dd8859d1215dbf89ec2495b6904
Author: Antoine Pitrou <[email protected]>
AuthorDate: Fri Mar 23 17:03:29 2018 +0100

    ARROW-2331: [Python] Fix indexing for negative or out-of-bounds indices
    
    Author: Antoine Pitrou <[email protected]>
    
    Closes #1770 from pitrou/ARROW-2331-python-indexing and squashes the following commits:
    
    aec1ef0 <Antoine Pitrou> Try to fix downcast errors
    1a38451 <Antoine Pitrou> ARROW-2331:  Fix indexing for negative or out-of-bounds indices
---
 python/pyarrow/array.pxi             | 19 +++++++++++++++----
 python/pyarrow/lib.pxd               |  2 ++
 python/pyarrow/scalar.pxi            |  7 +++++--
 python/pyarrow/table.pxi             | 19 ++++++++-----------
 python/pyarrow/tests/test_array.py   | 14 ++++++++++++++
 python/pyarrow/tests/test_scalars.py |  6 ++++++
 python/pyarrow/tests/test_table.py   |  6 ++++++
 python/pyarrow/types.pxi             | 30 +++++++-----------------------
 8 files changed, 63 insertions(+), 40 deletions(-)

diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi
index afb68a2..c40d7b5 100644
--- a/python/pyarrow/array.pxi
+++ b/python/pyarrow/array.pxi
@@ -232,6 +232,17 @@ def _normalize_slice(object arrow_obj, slice key):
         return arrow_obj.slice(start, stop - start)
 
 
+cdef Py_ssize_t _normalize_index(Py_ssize_t index,
+                                 Py_ssize_t length) except -1:
+    if index < 0:
+        index += length
+        if index < 0:
+            raise IndexError("index out of bounds")
+    elif index >= length:
+        raise IndexError("index out of bounds")
+    return index
+
+
 cdef class _FunctionContext:
     cdef:
         unique_ptr[CFunctionContext] ctx
@@ -427,6 +438,9 @@ cdef class Array:
         return self.ap.Equals(deref(other.ap))
 
     def __len__(self):
+        return self.length()
+
+    cdef int64_t length(self):
         if self.sp_array.get():
             return self.sp_array.get().length()
         else:
@@ -441,10 +455,7 @@ cdef class Array:
         if PySlice_Check(key):
             return _normalize_slice(self, key)
 
-        while key < 0:
-            key += len(self)
-
-        return self.getitem(key)
+        return self.getitem(_normalize_index(key, self.length()))
 
     cdef getitem(self, int64_t i):
         return box_scalar(self.type, self.sp_array, i)
diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd
index be103b3..7a1b221 100644
--- a/python/pyarrow/lib.pxd
+++ b/python/pyarrow/lib.pxd
@@ -137,6 +137,7 @@ cdef class ListValue(ArrayValue):
         CListArray* ap
 
     cdef getitem(self, int64_t i)
+    cdef int64_t length(self)
 
 
 cdef class UnionValue(ArrayValue):
@@ -164,6 +165,7 @@ cdef class Array:
 
     cdef void init(self, const shared_ptr[CArray]& sp_array)
     cdef getitem(self, int64_t i)
+    cdef int64_t length(self)
 
 
 cdef class Tensor:
diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi
index a0f8480..e9c2c7c 100644
--- a/python/pyarrow/scalar.pxi
+++ b/python/pyarrow/scalar.pxi
@@ -294,10 +294,10 @@ cdef class BinaryValue(ArrayValue):
 cdef class ListValue(ArrayValue):
 
     def __len__(self):
-        return self.ap.value_length(self.index)
+        return self.length()
 
     def __getitem__(self, i):
-        return self.getitem(i)
+        return self.getitem(_normalize_index(i, self.length()))
 
     def __iter__(self):
         for i in range(len(self)):
@@ -313,6 +313,9 @@ cdef class ListValue(ArrayValue):
         cdef int64_t j = self.ap.value_offset(self.index) + i
         return box_scalar(self.value_type, self.ap.values(), j)
 
+    cdef int64_t length(self):
+        return self.ap.value_length(self.index)
+
     def as_py(self):
         cdef:
             int64_t j
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 6cfa987..672b9fb 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -84,14 +84,12 @@ cdef class ChunkedArray:
         if isinstance(key, slice):
             return _normalize_slice(self, key)
         elif isinstance(key, six.integer_types):
-            item = key
-            if item >= self.chunked_array.length() or item < 0:
-                return IndexError("ChunkedArray selection out of bounds")
+            index = _normalize_index(key, self.chunked_array.length())
             for i in range(self.num_chunks):
-                if item < self.chunked_array.chunk(i).get().length():
-                    return self.chunk(i)[item]
+                if index < self.chunked_array.chunk(i).get().length():
+                    return self.chunk(i)[index]
                 else:
-                    item -= self.chunked_array.chunk(i).get().length()
+                    index -= self.chunked_array.chunk(i).get().length()
         else:
             raise TypeError("key must either be a slice or integer")
 
@@ -630,12 +628,10 @@ cdef class RecordBatch:
         return pyarrow_wrap_array(self.batch.column(i))
 
     def __getitem__(self, key):
-        cdef:
-            Py_ssize_t start, stop
         if isinstance(key, slice):
             return _normalize_slice(self, key)
         else:
-            return self.column(key)
+            return self.column(_normalize_index(key, self.num_columns))
 
     def serialize(self, memory_pool=None):
         """
@@ -1183,8 +1179,9 @@ cdef class Table:
         column.init(self.table.column(index))
         return column
 
-    def __getitem__(self, int64_t i):
-        return self.column(i)
+    def __getitem__(self, key):
+        cdef int index = <int> _normalize_index(key, self.num_columns)
+        return self.column(index)
 
     def itercolumns(self):
         """
diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py
index d126db3..4a441fb 100644
--- a/python/pyarrow/tests/test_array.py
+++ b/python/pyarrow/tests/test_array.py
@@ -110,6 +110,20 @@ def test_to_pandas_zero_copy():
         np_arr.sum()
 
 
+def test_array_getitem():
+    arr = pa.array(range(10, 15))
+    lst = arr.to_pylist()
+
+    for idx in range(-len(arr), len(arr)):
+        assert arr[idx].as_py() == lst[idx]
+    for idx in range(-2 * len(arr), -len(arr)):
+        with pytest.raises(IndexError):
+            arr[idx]
+    for idx in range(len(arr), 2 * len(arr)):
+        with pytest.raises(IndexError):
+            arr[idx]
+
+
 def test_array_slice():
     arr = pa.array(range(10))
 
diff --git a/python/pyarrow/tests/test_scalars.py b/python/pyarrow/tests/test_scalars.py
index 7061a0d..c63be02 100644
--- a/python/pyarrow/tests/test_scalars.py
+++ b/python/pyarrow/tests/test_scalars.py
@@ -131,6 +131,12 @@ class TestScalars(unittest.TestCase):
         assert v.as_py() == ['foo', None]
         assert v[0].as_py() == 'foo'
         assert v[1] is pa.NA
+        assert v[-1] == v[1]
+        assert v[-2] == v[0]
+        with pytest.raises(IndexError):
+            v[-3]
+        with pytest.raises(IndexError):
+            v[2]
 
         assert arr[1] is pa.NA
 
diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py
index 356ecb7..8156435 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -31,6 +31,12 @@ def test_chunked_array_getitem():
     ]
     data = pa.chunked_array(data)
     assert data[1].as_py() == 2
+    assert data[-1].as_py() == 6
+    assert data[-6].as_py() == 1
+    with pytest.raises(IndexError):
+        data[6]
+    with pytest.raises(IndexError):
+        data[-7]
 
     data_slice = data[2:4]
     assert data_slice.to_pylist() == [3, 4]
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index a4391c7..b0557eb 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -166,10 +166,8 @@ cdef class StructType(DataType):
         DataType.init(self, type)
 
     def __getitem__(self, i):
-        if i < 0 or i >= self.num_children:
-            raise IndexError(i)
-
-        return pyarrow_wrap_field(self.type.child(i))
+        cdef int index = <int> _normalize_index(i, self.num_children)
+        return pyarrow_wrap_field(self.type.child(index))
 
     property num_children:
 
@@ -207,7 +205,8 @@ cdef class UnionType(DataType):
             assert 0
 
     def __getitem__(self, i):
-        return pyarrow_wrap_field(self.type.child(i))
+        cdef int index = <int> _normalize_index(i, self.num_children)
+        return pyarrow_wrap_field(self.type.child(index))
 
     def __getstate__(self):
         children = [self[i] for i in range(self.num_children)]
@@ -440,24 +439,9 @@ cdef class Schema:
     def __len__(self):
         return self.schema.num_fields()
 
-    def __getitem__(self, int i):
-        cdef:
-            Field result = Field()
-            int num_fields = self.schema.num_fields()
-            int index
-
-        if not -num_fields <= i < num_fields:
-            raise IndexError(
-                'Schema field index {:d} is out of range'.format(i)
-            )
-
-        index = i if i >= 0 else num_fields + i
-        assert index >= 0
-
-        result.init(self.schema.field(index))
-        result.type = pyarrow_wrap_data_type(result.field.type())
-
-        return result
+    def __getitem__(self, key):
+        cdef int index = <int> _normalize_index(key, self.schema.num_fields())
+        return pyarrow_wrap_field(self.schema.field(index))
 
     def __iter__(self):
         for i in range(len(self)):

-- 
To stop receiving notification emails like this one, please contact
[email protected].

Reply via email to