This is an automated email from the ASF dual-hosted git repository.
amolina pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new d55b1af7b7 GH-44125: [Python] Add concat_batches function (#44126)
d55b1af7b7 is described below
commit d55b1af7b78ed2210ad9705d7484b38f1744f37b
Author: Alessandro Molina <[email protected]>
AuthorDate: Wed Oct 16 23:46:20 2024 +0200
GH-44125: [Python] Add concat_batches function (#44126)
### Rationale for this change
Allows concatenating record batches in Python
### What changes are included in this PR?
Adds `concat_batches` function and tests
### Are these changes tested?
yes
### Are there any user-facing changes?
A new public function has been added
* GitHub Issue: #44125
---------
Co-authored-by: Joris Van den Bossche <[email protected]>
---
docs/source/python/api/tables.rst | 1 +
python/pyarrow/__init__.py | 2 +-
python/pyarrow/includes/libarrow.pxd | 4 +++
python/pyarrow/table.pxi | 51 ++++++++++++++++++++++++++++++++++++
python/pyarrow/tests/test_table.py | 43 ++++++++++++++++++++++++++++++
5 files changed, 100 insertions(+), 1 deletion(-)
diff --git a/docs/source/python/api/tables.rst
b/docs/source/python/api/tables.rst
index ae9f5de127..48cc67eb66 100644
--- a/docs/source/python/api/tables.rst
+++ b/docs/source/python/api/tables.rst
@@ -32,6 +32,7 @@ Factory Functions
concat_arrays
concat_tables
record_batch
+ concat_batches
table
Classes
diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py
index d31c93119b..fb7c242187 100644
--- a/python/pyarrow/__init__.py
+++ b/python/pyarrow/__init__.py
@@ -267,7 +267,7 @@ from pyarrow.lib import (NativeFile, PythonFile,
from pyarrow.lib import (ChunkedArray, RecordBatch, Table, table,
concat_arrays, concat_tables, TableGroupBy,
- RecordBatchReader)
+ RecordBatchReader, concat_batches)
# Exceptions
from pyarrow.lib import (ArrowCancelled,
diff --git a/python/pyarrow/includes/libarrow.pxd
b/python/pyarrow/includes/libarrow.pxd
index 8e6922a912..d304641e0f 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -1356,6 +1356,10 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
CConcatenateTablesOptions options,
CMemoryPool* memory_pool)
+ CResult[shared_ptr[CRecordBatch]] ConcatenateRecordBatches(
+ const vector[shared_ptr[CRecordBatch]]& batches,
+ CMemoryPool* memory_pool)
+
cdef cppclass CDictionaryUnifier" arrow::DictionaryUnifier":
@staticmethod
CResult[shared_ptr[CChunkedArray]] UnifyChunkedArray(
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 819bbc34c6..af241e4be0 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -6259,6 +6259,57 @@ def concat_tables(tables, MemoryPool memory_pool=None,
str promote_options="none",
return pyarrow_wrap_table(c_result_table)
+def concat_batches(recordbatches, MemoryPool memory_pool=None):
+ """
+ Concatenate pyarrow.RecordBatch objects.
+
+ All recordbatches must share the same Schema,
+ the operation implies a copy of the data to merge
+ the arrays of the different RecordBatches.
+
+ Parameters
+ ----------
+ recordbatches : iterable of pyarrow.RecordBatch objects
+ Pyarrow record batches to concatenate into a single RecordBatch.
+ memory_pool : MemoryPool, default None
+ For memory allocations, if required, otherwise use default pool.
+
+ Examples
+ --------
+ >>> import pyarrow as pa
+ >>> t1 = pa.record_batch([
+ ... pa.array([2, 4, 5, 100]),
+ ... pa.array(["Flamingo", "Horse", "Brittle stars", "Centipede"])
+ ... ], names=['n_legs', 'animals'])
+ >>> t2 = pa.record_batch([
+ ... pa.array([2, 4]),
+ ... pa.array(["Parrot", "Dog"])
+ ... ], names=['n_legs', 'animals'])
+ >>> pa.concat_batches([t1,t2])
+ pyarrow.RecordBatch
+ n_legs: int64
+ animals: string
+ ----
+ n_legs: [2,4,5,100,2,4]
+ animals: ["Flamingo","Horse","Brittle stars","Centipede","Parrot","Dog"]
+
+ """
+ cdef:
+ vector[shared_ptr[CRecordBatch]] c_recordbatches
+ shared_ptr[CRecordBatch] c_result_recordbatch
+ RecordBatch recordbatch
+ CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
+
+ for recordbatch in recordbatches:
+ c_recordbatches.push_back(recordbatch.sp_batch)
+
+ with nogil:
+ c_result_recordbatch = GetResultValue(
+ ConcatenateRecordBatches(c_recordbatches, pool))
+
+ return pyarrow_wrap_batch(c_result_recordbatch)
+
+
def _from_pydict(cls, mapping, schema, metadata):
"""
Construct a Table/RecordBatch from Arrow arrays or columns.
diff --git a/python/pyarrow/tests/test_table.py
b/python/pyarrow/tests/test_table.py
index b66a5eb083..4c058ccecd 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -2037,6 +2037,49 @@ def test_table_negative_indexing():
table[4]
+def test_concat_batches():
+ data = [
+ list(range(5)),
+ [-10., -5., 0., 5., 10.]
+ ]
+ data2 = [
+ list(range(5, 10)),
+ [1., 2., 3., 4., 5.]
+ ]
+
+ t1 = pa.RecordBatch.from_arrays([pa.array(x) for x in data],
+ names=('a', 'b'))
+ t2 = pa.RecordBatch.from_arrays([pa.array(x) for x in data2],
+ names=('a', 'b'))
+
+ result = pa.concat_batches([t1, t2])
+ result.validate()
+ assert len(result) == 10
+
+ expected = pa.RecordBatch.from_arrays([pa.array(x + y)
+ for x, y in zip(data, data2)],
+ names=('a', 'b'))
+
+ assert result.equals(expected)
+
+
+def test_concat_batches_different_schema():
+ t1 = pa.RecordBatch.from_arrays(
+ [pa.array([1, 2], type=pa.int64())], ["f"])
+ t2 = pa.RecordBatch.from_arrays(
+ [pa.array([1, 2], type=pa.float32())], ["f"])
+
+ with pytest.raises(pa.ArrowInvalid,
+ match="not match index 0 recordbatch schema"):
+ pa.concat_batches([t1, t2])
+
+
+def test_concat_batches_none_batches():
+ # ARROW-11997
+ with pytest.raises(AttributeError):
+ pa.concat_batches([None])
+
+
@pytest.mark.parametrize(
('cls'),
[