Repository: arrow
Updated Branches:
  refs/heads/master 5aea3a3d9 -> b4e9ba1ae


ARROW-968: [Python] Support slices in RecordBatch.__getitem__

Author: Wes McKinney <[email protected]>

Closes #908 from wesm/ARROW-968 and squashes the following commits:

47b71a5d [Wes McKinney] Support slices in RecordBatch.__getitem__


Project: http://git-wip-us.apache.org/repos/asf/arrow/repo
Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/b4e9ba1a
Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/b4e9ba1a
Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/b4e9ba1a

Branch: refs/heads/master
Commit: b4e9ba1ae68bcc449e4426b7c08d2984ed20c6be
Parents: 5aea3a3
Author: Wes McKinney <[email protected]>
Authored: Sat Jul 29 11:00:58 2017 -0400
Committer: Wes McKinney <[email protected]>
Committed: Sat Jul 29 11:00:58 2017 -0400

----------------------------------------------------------------------
 python/pyarrow/array.pxi           | 34 ++++++++++++++++++---------------
 python/pyarrow/table.pxi           |  9 +++++++--
 python/pyarrow/tests/test_table.py | 11 +++++++++--
 3 files changed, 35 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/arrow/blob/b4e9ba1a/python/pyarrow/array.pxi
----------------------------------------------------------------------
diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi
index efbe36f..67418aa 100644
--- a/python/pyarrow/array.pxi
+++ b/python/pyarrow/array.pxi
@@ -89,6 +89,23 @@ def array(object sequence, DataType type=None, MemoryPool memory_pool=None,
     return pyarrow_wrap_array(sp_array)
 
 
+def _normalize_slice(object arrow_obj, slice key):
+    cdef Py_ssize_t n = len(arrow_obj)
+
+    start = key.start or 0
+    while start < 0:
+        start += n
+
+    stop = key.stop if key.stop is not None else n
+    while stop < 0:
+        stop += n
+
+    step = key.step or 1
+    if step != 1:
+        raise IndexError('only slices with step 1 supported')
+    else:
+        return arrow_obj.slice(start, stop - start)
+
 
 cdef class Array:
 
@@ -230,23 +247,10 @@ cdef class Array:
         raise NotImplemented
 
     def __getitem__(self, key):
-        cdef:
-            Py_ssize_t n = len(self)
+        cdef Py_ssize_t n = len(self)
 
         if PySlice_Check(key):
-            start = key.start or 0
-            while start < 0:
-                start += n
-
-            stop = key.stop if key.stop is not None else n
-            while stop < 0:
-                stop += n
-
-            step = key.step or 1
-            if step != 1:
-                raise IndexError('only slices with step 1 supported')
-            else:
-                return self.slice(start, stop - start)
+            return _normalize_slice(self, key)
 
         while key < 0:
             key += len(self)

http://git-wip-us.apache.org/repos/asf/arrow/blob/b4e9ba1a/python/pyarrow/table.pxi
----------------------------------------------------------------------
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 6188e90..a9cb064 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -475,8 +475,13 @@ cdef class RecordBatch:
             )
         return pyarrow_wrap_array(self.batch.column(i))
 
-    def __getitem__(self, i):
-        return self.column(i)
+    def __getitem__(self, key):
+        cdef:
+            Py_ssize_t start, stop
+        if isinstance(key, slice):
+            return _normalize_slice(self, key)
+        else:
+            return self.column(key)
 
     def slice(self, offset=0, length=None):
         """

http://git-wip-us.apache.org/repos/asf/arrow/blob/b4e9ba1a/python/pyarrow/tests/test_table.py
----------------------------------------------------------------------
diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py
index c2aeda9..28b98f0 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -80,7 +80,7 @@ def test_recordbatch_basics():
         batch[2]
 
 
-def test_recordbatch_slice():
+def test_recordbatch_slice_getitem():
     data = [
         pa.array(range(5)),
         pa.array([-10, -5, 0, 5, 10])
@@ -90,7 +90,6 @@ def test_recordbatch_slice():
     batch = pa.RecordBatch.from_arrays(data, names)
 
     sliced = batch.slice(2)
-
     assert sliced.num_rows == 3
 
     expected = pa.RecordBatch.from_arrays(
@@ -111,6 +110,14 @@ def test_recordbatch_slice():
     with pytest.raises(IndexError):
         batch.slice(-1)
 
+    # Check __getitem__-based slicing
+    assert batch.slice(0, 0).equals(batch[:0])
+    assert batch.slice(0, 2).equals(batch[:2])
+    assert batch.slice(2, 2).equals(batch[2:4])
+    assert batch.slice(2, len(batch) - 2).equals(batch[2:])
+    assert batch.slice(len(batch) - 2, 2).equals(batch[-2:])
+    assert batch.slice(len(batch) - 4, 2).equals(batch[-4:-2])
+
 
 def test_recordbatch_from_to_pandas():
     data = pd.DataFrame({

Reply via email to