This is an automated email from the ASF dual-hosted git repository.

amolina pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new d55b1af7b7 GH-44125: [Python] Add concat_batches function (#44126)
d55b1af7b7 is described below

commit d55b1af7b78ed2210ad9705d7484b38f1744f37b
Author: Alessandro Molina <[email protected]>
AuthorDate: Wed Oct 16 23:46:20 2024 +0200

    GH-44125: [Python] Add concat_batches function (#44126)
    
    ### Rationale for this change
    
    Allows concatenating record batches in Python.
    
    ### What changes are included in this PR?
    
    Adds `concat_batches` function and tests
    
    ### Are these changes tested?
    
    yes
    
    ### Are there any user-facing changes?
    
    A new public function, `pyarrow.concat_batches`, has been added.
    * GitHub Issue: #44125
    
    ---------
    
    Co-authored-by: Joris Van den Bossche <[email protected]>
---
 docs/source/python/api/tables.rst    |  1 +
 python/pyarrow/__init__.py           |  2 +-
 python/pyarrow/includes/libarrow.pxd |  4 +++
 python/pyarrow/table.pxi             | 51 ++++++++++++++++++++++++++++++++++++
 python/pyarrow/tests/test_table.py   | 43 ++++++++++++++++++++++++++++++
 5 files changed, 100 insertions(+), 1 deletion(-)

diff --git a/docs/source/python/api/tables.rst b/docs/source/python/api/tables.rst
index ae9f5de127..48cc67eb66 100644
--- a/docs/source/python/api/tables.rst
+++ b/docs/source/python/api/tables.rst
@@ -32,6 +32,7 @@ Factory Functions
    concat_arrays
    concat_tables
    record_batch
+   concat_batches
    table
 
 Classes
diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py
index d31c93119b..fb7c242187 100644
--- a/python/pyarrow/__init__.py
+++ b/python/pyarrow/__init__.py
@@ -267,7 +267,7 @@ from pyarrow.lib import (NativeFile, PythonFile,
 
 from pyarrow.lib import (ChunkedArray, RecordBatch, Table, table,
                          concat_arrays, concat_tables, TableGroupBy,
-                         RecordBatchReader)
+                         RecordBatchReader, concat_batches)
 
 # Exceptions
 from pyarrow.lib import (ArrowCancelled,
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 8e6922a912..d304641e0f 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -1356,6 +1356,10 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
         CConcatenateTablesOptions options,
         CMemoryPool* memory_pool)
 
+    CResult[shared_ptr[CRecordBatch]] ConcatenateRecordBatches(
+        const vector[shared_ptr[CRecordBatch]]& batches,
+        CMemoryPool* memory_pool)
+
     cdef cppclass CDictionaryUnifier" arrow::DictionaryUnifier":
         @staticmethod
         CResult[shared_ptr[CChunkedArray]] UnifyChunkedArray(
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 819bbc34c6..af241e4be0 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -6259,6 +6259,57 @@ def concat_tables(tables, MemoryPool memory_pool=None, str promote_options="none
     return pyarrow_wrap_table(c_result_table)
 
 
+def concat_batches(recordbatches, MemoryPool memory_pool=None):
+    """
+    Concatenate pyarrow.RecordBatch objects.
+
+    All record batches must share the same Schema.
+    The operation copies the data in order to merge
+    the arrays of the different RecordBatches.
+
+    Parameters
+    ----------
+    recordbatches : iterable of pyarrow.RecordBatch objects
+        Pyarrow record batches to concatenate into a single RecordBatch.
+    memory_pool : MemoryPool, default None
+        For memory allocations, if required, otherwise use default pool.
+
+    Examples
+    --------
+    >>> import pyarrow as pa
+    >>> t1 = pa.record_batch([
+    ...     pa.array([2, 4, 5, 100]),
+    ...     pa.array(["Flamingo", "Horse", "Brittle stars", "Centipede"])
+    ...     ], names=['n_legs', 'animals'])
+    >>> t2 = pa.record_batch([
+    ...     pa.array([2, 4]),
+    ...     pa.array(["Parrot", "Dog"])
+    ...     ], names=['n_legs', 'animals'])
+    >>> pa.concat_batches([t1,t2])
+    pyarrow.RecordBatch
+    n_legs: int64
+    animals: string
+    ----
+    n_legs: [2,4,5,100,2,4]
+    animals: ["Flamingo","Horse","Brittle stars","Centipede","Parrot","Dog"]
+
+    """
+    cdef:
+        vector[shared_ptr[CRecordBatch]] c_recordbatches
+        shared_ptr[CRecordBatch] c_result_recordbatch
+        RecordBatch recordbatch
+        CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
+
+    for recordbatch in recordbatches:  # gather the wrapped C++ batch handles
+        c_recordbatches.push_back(recordbatch.sp_batch)
+
+    with nogil:  # the C++ entry point is declared nogil in libarrow.pxd
+        c_result_recordbatch = GetResultValue(
+            ConcatenateRecordBatches(c_recordbatches, pool))
+
+    return pyarrow_wrap_batch(c_result_recordbatch)
+
+
 def _from_pydict(cls, mapping, schema, metadata):
     """
     Construct a Table/RecordBatch from Arrow arrays or columns.
diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py
index b66a5eb083..4c058ccecd 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -2037,6 +2037,49 @@ def test_table_negative_indexing():
         table[4]
 
 
+def test_concat_batches():
+    data = [  # columns 'a' and 'b' of the first batch
+        list(range(5)),
+        [-10., -5., 0., 5., 10.]
+    ]
+    data2 = [  # columns 'a' and 'b' of the second batch
+        list(range(5, 10)),
+        [1., 2., 3., 4., 5.]
+    ]
+
+    t1 = pa.RecordBatch.from_arrays([pa.array(x) for x in data],
+                                    names=('a', 'b'))
+    t2 = pa.RecordBatch.from_arrays([pa.array(x) for x in data2],
+                                    names=('a', 'b'))
+
+    result = pa.concat_batches([t1, t2])
+    result.validate()
+    assert len(result) == 10  # 5 rows from t1 + 5 rows from t2
+
+    expected = pa.RecordBatch.from_arrays([pa.array(x + y)
+                                           for x, y in zip(data, data2)],
+                                          names=('a', 'b'))
+
+    assert result.equals(expected)
+
+
+def test_concat_batches_different_schema():
+    t1 = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2], type=pa.int64())], ["f"])
+    t2 = pa.RecordBatch.from_arrays(  # same field name, float32 vs int64
+        [pa.array([1, 2], type=pa.float32())], ["f"])
+
+    with pytest.raises(pa.ArrowInvalid,
+                       match="not match index 0 recordbatch schema"):
+        pa.concat_batches([t1, t2])
+
+
+def test_concat_batches_none_batches():
+    # ARROW-11997 (NOTE(review): this change is GH-44125; confirm reference)
+    with pytest.raises(AttributeError):
+        pa.concat_batches([None])  # a None element must raise
+
+
 @pytest.mark.parametrize(
     ('cls'),
     [

Reply via email to