This is an automated email from the ASF dual-hosted git repository.

brycemecum pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1969ae342a GH-46729: [Python] Allow constructing InMemoryDataset from RecordBatchReader (#46731)
1969ae342a is described below

commit 1969ae342af03ae6c6dba5d6706ae0fc61a18ac2
Author: Bryce Mecum <[email protected]>
AuthorDate: Fri Jun 6 11:37:43 2025 -0700

    GH-46729: [Python] Allow constructing InMemoryDataset from RecordBatchReader (#46731)
    
    ### Rationale for this change
    
    Our docs say you can construct a Dataset from a RecordBatchReader, but you can't. While we can't pass the actual RecordBatchReader to the Dataset as a source (AFAIK), we can at least consume the reader immediately and create an InMemoryDataset from the resulting batches.
    
    ### What changes are included in this PR?
    
    - Tweaked type checks so that a RecordBatchReader is now accepted as a source (via both ds.dataset and ds.InMemoryDataset)
    - Test case extended to cover the new behavior
    - Tweaked the error message to use proper casing for type names and to mention RecordBatchReader as an accepted source
    
    ### Are these changes tested?
    
    Yes.
    
    ### Are there any user-facing changes?
    
    No.
    * GitHub Issue: #46729
    
    Authored-by: Bryce Mecum <[email protected]>
    Signed-off-by: Bryce Mecum <[email protected]>
---
 python/pyarrow/_dataset.pyx          | 6 +++---
 python/pyarrow/dataset.py            | 2 +-
 python/pyarrow/tests/test_dataset.py | 3 ++-
 3 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index 9e5edee574..478c6b3f7c 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -1011,7 +1011,7 @@ cdef class InMemoryDataset(Dataset):
         if isinstance(source, (pa.RecordBatch, pa.Table)):
             source = [source]
 
-        if isinstance(source, (list, tuple)):
+        if isinstance(source, (list, tuple, pa.RecordBatchReader)):
             batches = []
             for item in source:
                 if isinstance(item, pa.RecordBatch):
@@ -1036,8 +1036,8 @@ cdef class InMemoryDataset(Dataset):
                 pyarrow_unwrap_table(table))
         else:
             raise TypeError(
-                'Expected a table, batch, or list of tables/batches '
-                'instead of the given type: ' +
+                'Expected a Table, RecordBatch, list of Table/RecordBatch, '
+                'or RecordBatchReader instead of the given type: ' +
                 type(source).__name__
             )
 
diff --git a/python/pyarrow/dataset.py b/python/pyarrow/dataset.py
index 26602c1e17..ef4f728872 100644
--- a/python/pyarrow/dataset.py
+++ b/python/pyarrow/dataset.py
@@ -804,7 +804,7 @@ RecordBatch or Table, iterable of RecordBatch, 
RecordBatchReader, or URI
                 'of batches or tables. The given list contains the following '
                 f'types: {type_names}'
             )
-    elif isinstance(source, (pa.RecordBatch, pa.Table)):
+    elif isinstance(source, (pa.RecordBatch, pa.Table, pa.RecordBatchReader)):
         return _in_memory_dataset(source, **kwargs)
     else:
         raise TypeError(
diff --git a/python/pyarrow/tests/test_dataset.py 
b/python/pyarrow/tests/test_dataset.py
index 4af0f914eb..c17e038713 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -2558,13 +2558,14 @@ def 
test_construct_from_invalid_sources_raise(multisourcefs):
 
 def test_construct_in_memory(dataset_reader):
     batch = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"])
+    reader = pa.RecordBatchReader.from_batches(batch.schema, [batch])
     table = pa.Table.from_batches([batch])
 
     dataset_table = ds.dataset([], format='ipc', schema=pa.schema([])
                                ).to_table()
     assert dataset_table == pa.table([])
 
-    for source in (batch, table, [batch], [table]):
+    for source in (batch, table, [batch], [table], reader):
         dataset = ds.dataset(source)
         assert dataset_reader.to_table(dataset) == table
         assert len(list(dataset.get_fragments())) == 1

Reply via email to