This is an automated email from the ASF dual-hosted git repository.

jorisvandenbossche pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6cf5e8984a GH-33926: [Python] DataFrame Interchange Protocol for pyarrow.RecordBatch (#34294)
6cf5e8984a is described below

commit 6cf5e8984a48e70352e65112c3254fc03557767a
Author: Alenka Frim <[email protected]>
AuthorDate: Tue Feb 28 10:59:52 2023 +0100

    GH-33926: [Python] DataFrame Interchange Protocol for pyarrow.RecordBatch (#34294)
    
    ### Rationale for this change
    Add the implementation of the DataFrame Interchange Protocol for `pyarrow.RecordBatch`. The protocol is already implemented for `pyarrow.Table`; see https://github.com/apache/arrow/pull/14804.
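    
    A minimal usage sketch of what this change enables (an editor's illustration
    based on the diff below, not code from the commit; the column name `n_legs`
    is made up):
    
    ```python
    import pyarrow as pa
    from pyarrow.interchange import from_dataframe
    
    # New in this commit: RecordBatch exposes the interchange protocol directly.
    batch = pa.RecordBatch.from_pydict({"n_legs": [2, 4, 100]})
    interchange_obj = batch.__dataframe__()
    
    # A RecordBatch always counts as a single chunk.
    assert interchange_obj.num_chunks() == 1
    
    # from_dataframe() now accepts a RecordBatch and wraps it in a Table.
    table = from_dataframe(batch)
    assert table.equals(pa.Table.from_batches([batch]))
    ```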
    
    ### Are these changes tested?
    Yes, tests are added to:
    
    - python/pyarrow/tests/interchange/test_interchange_spec.py
    - python/pyarrow/tests/interchange/test_conversion.py
    * Closes: #33926
    
    Authored-by: Alenka Frim <[email protected]>
    Signed-off-by: Joris Van den Bossche <[email protected]>
---
 python/pyarrow/interchange/dataframe.py            | 51 ++++++++++++++--------
 python/pyarrow/interchange/from_dataframe.py       |  8 ++--
 python/pyarrow/table.pxi                           | 33 ++++++++++++++
 .../pyarrow/tests/interchange/test_conversion.py   |  3 +-
 .../tests/interchange/test_interchange_spec.py     | 41 +++++++++++++----
 5 files changed, 105 insertions(+), 31 deletions(-)

diff --git a/python/pyarrow/interchange/dataframe.py b/python/pyarrow/interchange/dataframe.py
index d0717e02e8..59ba765c17 100644
--- a/python/pyarrow/interchange/dataframe.py
+++ b/python/pyarrow/interchange/dataframe.py
@@ -44,11 +44,13 @@ class _PyArrowDataFrame:
     """
 
     def __init__(
-        self, df: pa.Table, nan_as_null: bool = False, allow_copy: bool = True
+        self, df: pa.Table | pa.RecordBatch,
+        nan_as_null: bool = False,
+        allow_copy: bool = True
     ) -> None:
         """
         Constructor - an instance of this (private) class is returned from
-        `pa.Table.__dataframe__`.
+        `pa.Table.__dataframe__` or `pa.RecordBatch.__dataframe__`.
         """
         self._df = df
         # ``nan_as_null`` is a keyword intended for the consumer to tell the
@@ -114,18 +116,21 @@ class _PyArrowDataFrame:
         """
         Return the number of chunks the DataFrame consists of.
         """
-        # pyarrow.Table can have columns with different number
-        # of chunks so we take the number of chunks that
-        # .to_batches() returns as it takes the min chunk size
-        # of all the columns (to_batches is a zero copy method)
-        batches = self._df.to_batches()
-        return len(batches)
+        if isinstance(self._df, pa.RecordBatch):
+            return 1
+        else:
+            # pyarrow.Table can have columns with different number
+            # of chunks so we take the number of chunks that
+            # .to_batches() returns as it takes the min chunk size
+            # of all the columns (to_batches is a zero copy method)
+            batches = self._df.to_batches()
+            return len(batches)
 
     def column_names(self) -> Iterable[str]:
         """
         Return an iterator yielding the column names.
         """
-        return self._df.column_names
+        return self._df.schema.names
 
     def get_column(self, i: int) -> _PyArrowColumn:
         """
@@ -182,21 +187,31 @@ class _PyArrowDataFrame:
         Note that the producer must ensure that all columns are chunked the
         same way.
         """
+        # Subdivide chunks
         if n_chunks and n_chunks > 1:
             chunk_size = self.num_rows() // n_chunks
             if self.num_rows() % n_chunks != 0:
                 chunk_size += 1
-            batches = self._df.to_batches(max_chunksize=chunk_size)
+            if isinstance(self._df, pa.Table):
+                batches = self._df.to_batches(max_chunksize=chunk_size)
+            else:
+                batches = []
+                for start in range(0, chunk_size * n_chunks, chunk_size):
+                    batches.append(self._df.slice(start, chunk_size))
+            # In case the size of the chunk is such that the resulting
+            # list is one chunk less than n_chunks -> append an empty chunk
             if len(batches) == n_chunks - 1:
                 batches.append(pa.record_batch([[]], schema=self._df.schema))
+        # Otherwise, yield the chunks the data is already stored as
         else:
-            batches = self._df.to_batches()
-
-        iterator_tables = [_PyArrowDataFrame(
-            pa.Table.from_batches([batch]), self._nan_as_null, self._allow_copy
-        )
-            for batch in batches
-        ]
-        return iterator_tables
+            if isinstance(self._df, pa.Table):
+                batches = self._df.to_batches()
+            else:
+                batches = [self._df]
+
+        # Create an iterator of RecordBatches
+        iterator = [_PyArrowDataFrame(batch,
+                                      self._nan_as_null,
+                                      self._allow_copy)
+                    for batch in batches]
+        return iterator
diff --git a/python/pyarrow/interchange/from_dataframe.py b/python/pyarrow/interchange/from_dataframe.py
index 204530a335..801d0dd452 100644
--- a/python/pyarrow/interchange/from_dataframe.py
+++ b/python/pyarrow/interchange/from_dataframe.py
@@ -60,8 +60,7 @@ _PYARROW_DTYPES: dict[DtypeKind, dict[int, Any]] = {
 
 def from_dataframe(df: DataFrameObject, allow_copy=True) -> pa.Table:
     """
-    Build a ``pa.Table`` from any DataFrame supporting the interchange
-    protocol.
+    Build a ``pa.Table`` from any DataFrame supporting the interchange protocol.
 
     Parameters
     ----------
@@ -78,6 +77,8 @@ def from_dataframe(df: DataFrameObject, allow_copy=True) -> pa.Table:
     """
     if isinstance(df, pa.Table):
         return df
+    elif isinstance(df, pa.RecordBatch):
+        return pa.Table.from_batches([df])
 
     if not hasattr(df, "__dataframe__"):
         raise ValueError("`df` does not support __dataframe__")
@@ -108,8 +109,7 @@ def _from_dataframe(df: DataFrameObject, allow_copy=True):
         batch = protocol_df_chunk_to_pyarrow(chunk, allow_copy)
         batches.append(batch)
 
-    table = pa.Table.from_batches(batches)
-    return table
+    return pa.Table.from_batches(batches)
 
 
 def protocol_df_chunk_to_pyarrow(
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 5b7989e582..e400605e56 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -1517,6 +1517,39 @@ cdef class RecordBatch(_PandasConvertible):
         self.sp_batch = batch
         self.batch = batch.get()
 
+    # ----------------------------------------------------------------------
+    def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True):
+        """
+        Return the dataframe interchange object implementing the interchange protocol.
+
+        Parameters
+        ----------
+        nan_as_null : bool, default False
+            Whether to tell the DataFrame to overwrite null values in the data
+            with ``NaN`` (or ``NaT``).
+        allow_copy : bool, default True
+            Whether to allow memory copying when exporting. If set to False
+            it would cause non-zero-copy exports to fail.
+
+        Returns
+        -------
+        DataFrame interchange object
+            The object which the consuming library can use to ingress the dataframe.
+
+        Notes
+        -----
+        Details on the interchange protocol:
+        https://data-apis.org/dataframe-protocol/latest/index.html
+        `nan_as_null` currently has no effect; once support for nullable extension
+        dtypes is added, this value should be propagated to columns.
+        """
+
+        from pyarrow.interchange.dataframe import _PyArrowDataFrame
+
+        return _PyArrowDataFrame(self, nan_as_null, allow_copy)
+
+    # ----------------------------------------------------------------------
+
     @staticmethod
     def from_pydict(mapping, schema=None, metadata=None):
         """
diff --git a/python/pyarrow/tests/interchange/test_conversion.py b/python/pyarrow/tests/interchange/test_conversion.py
index 0680d9c4ec..089f316e50 100644
--- a/python/pyarrow/tests/interchange/test_conversion.py
+++ b/python/pyarrow/tests/interchange/test_conversion.py
@@ -108,6 +108,7 @@ def test_categorical_roundtrip():
 
     if Version(pd.__version__) < Version("1.5.0"):
         pytest.skip("__dataframe__ added to pandas in 1.5.0")
+
     arr = ["Mon", "Tue", "Mon", "Wed", "Mon", "Thu", "Fri", "Sat", "Sun"]
     table = pa.table(
         {"weekday": pa.array(arr).dictionary_encode()}
@@ -447,7 +448,7 @@ def test_pyarrow_roundtrip_categorical(offset, length):
     assert col_result.size() == col_table.size()
     assert col_result.offset == col_table.offset
 
-    desc_cat_table = col_result.describe_categorical
+    desc_cat_table = col_table.describe_categorical
     desc_cat_result = col_result.describe_categorical
 
     assert desc_cat_table["is_ordered"] == desc_cat_result["is_ordered"]
diff --git a/python/pyarrow/tests/interchange/test_interchange_spec.py b/python/pyarrow/tests/interchange/test_interchange_spec.py
index 42ec805359..7b2b8eb720 100644
--- a/python/pyarrow/tests/interchange/test_interchange_spec.py
+++ b/python/pyarrow/tests/interchange/test_interchange_spec.py
@@ -76,8 +76,10 @@ def test_dtypes(arr):
 )
 @pytest.mark.parametrize("unit", ['s', 'ms', 'us', 'ns'])
 @pytest.mark.parametrize("tz", ['', 'America/New_York', '+07:30', '-04:30'])
[email protected]("use_batch", [False, True])
 def test_mixed_dtypes(uint, uint_bw, int, int_bw,
-                      float, float_bw, np_float, unit, tz):
+                      float, float_bw, np_float, unit, tz,
+                      use_batch):
     from datetime import datetime as dt
     arr = [1, 2, 3]
     dt_arr = [dt(2007, 7, 13), dt(2007, 7, 14), dt(2007, 7, 15)]
@@ -91,6 +93,8 @@ def test_mixed_dtypes(uint, uint_bw, int, int_bw,
             "f": pa.array(dt_arr, type=pa.timestamp(unit, tz=tz))
         }
     )
+    if use_batch:
+        table = table.to_batches()[0]
     df = table.__dataframe__()
     # 0 = DtypeKind.INT, 1 = DtypeKind.UINT, 2 = DtypeKind.FLOAT,
     # 20 = DtypeKind.BOOL, 21 = DtypeKind.STRING, 22 = DtypeKind.DATETIME
@@ -126,12 +130,15 @@ def test_noncategorical():
         col.describe_categorical
 
 
-def test_categorical():
[email protected]("use_batch", [False, True])
+def test_categorical(use_batch):
     import pyarrow as pa
     arr = ["Mon", "Tue", "Mon", "Wed", "Mon", "Thu", "Fri", "Sat", None]
     table = pa.table(
         {"weekday": pa.array(arr).dictionary_encode()}
     )
+    if use_batch:
+        table = table.to_batches()[0]
 
     col = table.__dataframe__().get_column_by_name("weekday")
     categorical = col.describe_categorical
@@ -139,34 +146,46 @@ def test_categorical():
     assert isinstance(categorical["is_dictionary"], bool)
 
 
-def test_dataframe():
[email protected]("use_batch", [False, True])
+def test_dataframe(use_batch):
     n = pa.chunked_array([[2, 2, 4], [4, 5, 100]])
     a = pa.chunked_array([["Flamingo", "Parrot", "Cow"],
                          ["Horse", "Brittle stars", "Centipede"]])
     table = pa.table([n, a], names=['n_legs', 'animals'])
+    if use_batch:
+        table = table.combine_chunks().to_batches()[0]
     df = table.__dataframe__()
 
     assert df.num_columns() == 2
     assert df.num_rows() == 6
-    assert df.num_chunks() == 2
+    if use_batch:
+        assert df.num_chunks() == 1
+    else:
+        assert df.num_chunks() == 2
     assert list(df.column_names()) == ['n_legs', 'animals']
     assert list(df.select_columns((1,)).column_names()) == list(
         df.select_columns_by_name(("animals",)).column_names()
     )
 
 
[email protected]("use_batch", [False, True])
 @pytest.mark.parametrize(["size", "n_chunks"], [(10, 3), (12, 3), (12, 5)])
-def test_df_get_chunks(size, n_chunks):
+def test_df_get_chunks(use_batch, size, n_chunks):
     table = pa.table({"x": list(range(size))})
+    if use_batch:
+        table = table.to_batches()[0]
     df = table.__dataframe__()
     chunks = list(df.get_chunks(n_chunks))
     assert len(chunks) == n_chunks
     assert sum(chunk.num_rows() for chunk in chunks) == size
 
 
[email protected]("use_batch", [False, True])
 @pytest.mark.parametrize(["size", "n_chunks"], [(10, 3), (12, 3), (12, 5)])
-def test_column_get_chunks(size, n_chunks):
+def test_column_get_chunks(use_batch, size, n_chunks):
     table = pa.table({"x": list(range(size))})
+    if use_batch:
+        table = table.to_batches()[0]
     df = table.__dataframe__()
     chunks = list(df.get_column(0).get_chunks(n_chunks))
     assert len(chunks) == n_chunks
@@ -187,7 +206,8 @@ def test_column_get_chunks(size, n_chunks):
         (pa.float64(), np.float64)
     ]
 )
-def test_get_columns(uint, int, float, np_float):
[email protected]("use_batch", [False, True])
+def test_get_columns(uint, int, float, np_float, use_batch):
     arr = [[1, 2, 3], [4, 5]]
     arr_float = np.array([1, 2, 3, 4, 5], dtype=np_float)
     table = pa.table(
@@ -197,6 +217,8 @@ def test_get_columns(uint, int, float, np_float):
             "c": pa.array(arr_float, type=float)
         }
     )
+    if use_batch:
+        table = table.combine_chunks().to_batches()[0]
     df = table.__dataframe__()
     for col in df.get_columns():
         assert col.size() == 5
@@ -212,9 +234,12 @@ def test_get_columns(uint, int, float, np_float):
 @pytest.mark.parametrize(
     "int", [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
 )
-def test_buffer(int):
[email protected]("use_batch", [False, True])
+def test_buffer(int, use_batch):
     arr = [0, 1, -1]
     table = pa.table({"a": pa.array(arr, type=int)})
+    if use_batch:
+        table = table.to_batches()[0]
     df = table.__dataframe__()
     col = df.get_column(0)
     buf = col.get_buffers()
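
For reference, an editor's sketch (not part of this commit) of the chunk
subdivision behavior that test_df_get_chunks above exercises, applied to a
RecordBatch:

    import pyarrow as pa

    batch = pa.RecordBatch.from_pydict({"x": list(range(10))})
    df = batch.__dataframe__()

    # get_chunks(3) computes chunk_size = ceil(10 / 3) = 4 and slices the
    # batch at offsets 0, 4 and 8, yielding chunks of 4, 4 and 2 rows.
    chunks = list(df.get_chunks(n_chunks=3))
    assert len(chunks) == 3
    assert sum(chunk.num_rows() for chunk in chunks) == 10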
