This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new ad752482f8 GH-43410: [Python] Support Arrow PyCapsule stream objects in write_dataset (#43771)
ad752482f8 is described below

commit ad752482f819df638d9438feff6cad3c49946fc7
Author: Joris Van den Bossche <[email protected]>
AuthorDate: Mon Nov 18 16:56:09 2024 +0100

    GH-43410: [Python] Support Arrow PyCapsule stream objects in write_dataset (#43771)
    
    ### Rationale for this change
    
    This expands the set of pyarrow APIs that accept objects implementing the Arrow PyCapsule interface. This PR adds such support to `ds.write_dataset()`, which already accepts a RecordBatchReader.
    
    ### What changes are included in this PR?
    
    `ds.write_dataset()` and `ds.Scanner.from_batches()` now accept any object implementing the Arrow PyCapsule interface for streams.
    
    ### Are these changes tested?
    
    Yes
    
    ### Are there any user-facing changes?
    
    No
    * GitHub Issue: #43410
    
    Authored-by: Joris Van den Bossche <[email protected]>
    Signed-off-by: Antoine Pitrou <[email protected]>
---
 python/pyarrow/_dataset.pyx          | 15 ++++++++++++---
 python/pyarrow/dataset.py            |  6 +++++-
 python/pyarrow/tests/test_dataset.py | 20 +++++++++++++++++---
 3 files changed, 34 insertions(+), 7 deletions(-)

diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index 3a4fa1ab61..fd50215cee 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -3716,10 +3716,13 @@ cdef class Scanner(_Weakrefable):
 
         Parameters
         ----------
-        source : Iterator
-            The iterator of Batches.
+        source : Iterator or Arrow-compatible stream object
+            The iterator of Batches. This can be a pyarrow RecordBatchReader,
+            any object that implements the Arrow PyCapsule Protocol for
+            streams, or an actual Python iterator of RecordBatches.
         schema : Schema
-            The schema of the batches.
+            The schema of the batches (required when passing a Python
+            iterator).
         columns : list[str] or dict[str, Expression], default None
             The columns to project. This can be a list of column names to
             include (order and duplicates will be preserved), or a dictionary
@@ -3775,6 +3778,12 @@ cdef class Scanner(_Weakrefable):
                 raise ValueError('Cannot specify a schema when providing '
                                  'a RecordBatchReader')
             reader = source
+        elif hasattr(source, "__arrow_c_stream__"):
+            if schema:
+                raise ValueError(
+                    'Cannot specify a schema when providing an object '
+                    'implementing the Arrow PyCapsule Protocol')
+            reader = pa.ipc.RecordBatchReader.from_stream(source)
         elif _is_iterable(source):
             if schema is None:
                 raise ValueError('Must provide schema to construct scanner '
diff --git a/python/pyarrow/dataset.py b/python/pyarrow/dataset.py
index 1efbfe1665..c61e13ee75 100644
--- a/python/pyarrow/dataset.py
+++ b/python/pyarrow/dataset.py
@@ -964,7 +964,11 @@ Table/RecordBatch, or iterable of RecordBatch
     elif isinstance(data, (pa.RecordBatch, pa.Table)):
         schema = schema or data.schema
         data = InMemoryDataset(data, schema=schema)
-    elif isinstance(data, pa.ipc.RecordBatchReader) or _is_iterable(data):
+    elif (
+        isinstance(data, pa.ipc.RecordBatchReader)
+        or hasattr(data, "__arrow_c_stream__")
+        or _is_iterable(data)
+    ):
         data = Scanner.from_batches(data, schema=schema)
         schema = None
     elif not isinstance(data, (Dataset, Scanner)):
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index 772670ad79..b6aaa2840d 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -66,6 +66,14 @@ except ImportError:
 pytestmark = pytest.mark.dataset
 
 
+class TableStreamWrapper:
+    def __init__(self, table):
+        self.table = table
+
+    def __arrow_c_stream__(self, requested_schema=None):
+        return self.table.__arrow_c_stream__(requested_schema)
+
+
 def _generate_data(n):
     import datetime
     import itertools
@@ -2543,6 +2551,7 @@ def test_scan_iterator(use_threads):
     for factory, schema in (
             (lambda: pa.RecordBatchReader.from_batches(
                 batch.schema, [batch]), None),
+            (lambda: TableStreamWrapper(table), None),
             (lambda: (batch for _ in range(1)), batch.schema),
     ):
         # Scanning the fragment consumes the underlying iterator
@@ -4674,15 +4683,20 @@ def test_write_iterable(tempdir):
     base_dir = tempdir / 'inmemory_iterable'
     ds.write_dataset((batch for batch in table.to_batches()), base_dir,
                      schema=table.schema,
-                     basename_template='dat_{i}.arrow', format="feather")
+                     basename_template='dat_{i}.arrow', format="ipc")
     result = ds.dataset(base_dir, format="ipc").to_table()
     assert result.equals(table)
 
     base_dir = tempdir / 'inmemory_reader'
     reader = pa.RecordBatchReader.from_batches(table.schema,
                                                table.to_batches())
-    ds.write_dataset(reader, base_dir,
-                     basename_template='dat_{i}.arrow', format="feather")
+    ds.write_dataset(reader, base_dir, basename_template='dat_{i}.arrow', 
format="ipc")
+    result = ds.dataset(base_dir, format="ipc").to_table()
+    assert result.equals(table)
+
+    base_dir = tempdir / 'inmemory_pycapsule'
+    stream = TableStreamWrapper(table)
+    ds.write_dataset(stream, base_dir, basename_template='dat_{i}.arrow', 
format="ipc")
     result = ds.dataset(base_dir, format="ipc").to_table()
     assert result.equals(table)
 

Reply via email to