This is an automated email from the ASF dual-hosted git repository.

jorisvandenbossche pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new e05f032  ARROW-7800 [Python] implement iter_batches() method for ParquetFile and ParquetReader
e05f032 is described below

commit e05f032c1e5d590ac56372d13ec637bd28b47a96
Author: Will Jones <[email protected]>
AuthorDate: Wed Jan 6 11:16:14 2021 +0100

    ARROW-7800 [Python] implement iter_batches() method for ParquetFile and ParquetReader
    
    Implements an `iter_batches()` method for `ParquetFile` and `ParquetReader`. This is an attempt to complete the work started in #6386 by @rollokb.
    
    I wasn't sure how to set the `batch_size` parameter in `ArrowReaderProperties` after the init, so I ended up expanding the C++ API so that `GetRecordBatchReader` took a `batch_size` parameter. I know nearly nothing about C++ and Cython, so if there's a better way I'd love to learn. But I saw [Wes' comment in the PR creating those methods suggesting you might want this](https://github.com/apache/arrow/pull/4304#pullrequestreview-253163586), so I thought there's a chance I'm not overste [...]
    
    Closes #6979 from ghost/ARROW-7800/expose-get-record-batch-reader
    
    Lead-authored-by: Will Jones <[email protected]>
    Co-authored-by: Will Jones <[email protected]>
    Co-authored-by: Will Jones <[email protected]>
    Signed-off-by: Joris Van den Bossche <[email protected]>
---
 cpp/src/parquet/arrow/reader.cc                   |  4 ++
 cpp/src/parquet/arrow/reader.h                    |  3 +
 python/pyarrow/_parquet.pxd                       | 12 +++-
 python/pyarrow/_parquet.pyx                       | 49 ++++++++++++++++
 python/pyarrow/parquet.py                         | 38 ++++++++++++
 python/pyarrow/tests/parquet/test_parquet_file.py | 70 +++++++++++++++++++++++
 6 files changed, 174 insertions(+), 2 deletions(-)

diff --git a/cpp/src/parquet/arrow/reader.cc b/cpp/src/parquet/arrow/reader.cc
index 0f4e218..3cc3070 100644
--- a/cpp/src/parquet/arrow/reader.cc
+++ b/cpp/src/parquet/arrow/reader.cc
@@ -322,6 +322,10 @@ class FileReaderImpl : public FileReader {
     reader_properties_.set_use_threads(use_threads);
   }
 
+  void set_batch_size(int64_t batch_size) override {
+    reader_properties_.set_batch_size(batch_size);
+  }
+
   const ArrowReaderProperties& properties() const override { return reader_properties_; }
 
   const SchemaManifest& manifest() const override { return manifest_; }
diff --git a/cpp/src/parquet/arrow/reader.h b/cpp/src/parquet/arrow/reader.h
index 8c1c73b..4e75b25 100644
--- a/cpp/src/parquet/arrow/reader.h
+++ b/cpp/src/parquet/arrow/reader.h
@@ -218,6 +218,9 @@ class PARQUET_EXPORT FileReader {
   /// By default only one thread is used.
   virtual void set_use_threads(bool use_threads) = 0;
 
+  /// Set number of records to read per batch for the RecordBatchReader.
+  virtual void set_batch_size(int64_t batch_size) = 0;
+
   virtual const ArrowReaderProperties& properties() const = 0;
 
   virtual const SchemaManifest& manifest() const = 0;
diff --git a/python/pyarrow/_parquet.pxd b/python/pyarrow/_parquet.pxd
index f1c0abf..8c0f5a9 100644
--- a/python/pyarrow/_parquet.pxd
+++ b/python/pyarrow/_parquet.pxd
@@ -23,7 +23,7 @@ from pyarrow.includes.libarrow cimport (CChunkedArray, CSchema, CStatus,
                                         CTable, CMemoryPool, CBuffer,
                                         CKeyValueMetadata,
                                         CRandomAccessFile, COutputStream,
-                                        TimeUnit)
+                                        TimeUnit, CRecordBatchReader)
 from pyarrow.lib cimport _Weakrefable
 
 
@@ -340,7 +340,7 @@ cdef extern from "parquet/api/reader.h" namespace "parquet" nogil:
         ArrowReaderProperties()
         void set_read_dictionary(int column_index, c_bool read_dict)
         c_bool read_dictionary()
-        void set_batch_size()
+        void set_batch_size(int64_t batch_size)
         int64_t batch_size()
 
     ArrowReaderProperties default_arrow_reader_properties()
@@ -407,6 +407,12 @@ cdef extern from "parquet/arrow/reader.h" namespace "parquet::arrow" nogil:
                               const vector[int]& column_indices,
                               shared_ptr[CTable]* out)
 
+        CStatus GetRecordBatchReader(const vector[int]& row_group_indices,
+                                     const vector[int]& column_indices,
+                                     unique_ptr[CRecordBatchReader]* out)
+        CStatus GetRecordBatchReader(const vector[int]& row_group_indices,
+                                     unique_ptr[CRecordBatchReader]* out)
+
         CStatus ReadTable(shared_ptr[CTable]* out)
         CStatus ReadTable(const vector[int]& column_indices,
                           shared_ptr[CTable]* out)
@@ -418,6 +424,8 @@ cdef extern from "parquet/arrow/reader.h" namespace "parquet::arrow" nogil:
 
         void set_use_threads(c_bool use_threads)
 
+        void set_batch_size(int64_t batch_size)
+
     cdef cppclass FileReaderBuilder:
         FileReaderBuilder()
         CStatus Open(const shared_ptr[CRandomAccessFile]& file,
diff --git a/python/pyarrow/_parquet.pyx b/python/pyarrow/_parquet.pyx
index 029df3e..53bbee6 100644
--- a/python/pyarrow/_parquet.pyx
+++ b/python/pyarrow/_parquet.pyx
@@ -36,6 +36,7 @@ from pyarrow.lib cimport (_Weakrefable, Buffer, Array, Schema,
                           pyarrow_wrap_schema,
                           pyarrow_wrap_table,
                           pyarrow_wrap_buffer,
+                          pyarrow_wrap_batch,
                           NativeFile, get_reader, get_writer)
 
 from pyarrow.lib import (ArrowException, NativeFile, BufferOutputStream,
@@ -1002,6 +1003,54 @@ cdef class ParquetReader(_Weakrefable):
     def set_use_threads(self, bint use_threads):
         self.reader.get().set_use_threads(use_threads)
 
+    def set_batch_size(self, int64_t batch_size):
+        self.reader.get().set_batch_size(batch_size)
+
+    def iter_batches(self, int64_t batch_size, row_groups, column_indices=None,
+                     bint use_threads=True):
+        cdef:
+            vector[int] c_row_groups
+            vector[int] c_column_indices
+            shared_ptr[CRecordBatch] record_batch
+            shared_ptr[TableBatchReader] batch_reader
+            unique_ptr[CRecordBatchReader] recordbatchreader
+
+        self.set_batch_size(batch_size)
+
+        if use_threads:
+            self.set_use_threads(use_threads)
+
+        for row_group in row_groups:
+            c_row_groups.push_back(row_group)
+
+        if column_indices is not None:
+            for index in column_indices:
+                c_column_indices.push_back(index)
+            with nogil:
+                check_status(
+                    self.reader.get().GetRecordBatchReader(
+                        c_row_groups, c_column_indices, &recordbatchreader
+                    )
+                )
+        else:
+            with nogil:
+                check_status(
+                    self.reader.get().GetRecordBatchReader(
+                        c_row_groups, &recordbatchreader
+                    )
+                )
+
+        while True:
+            with nogil:
+                check_status(
+                    recordbatchreader.get().ReadNext(&record_batch)
+                )
+
+            if record_batch.get() == NULL:
+                break
+
+            yield pyarrow_wrap_batch(record_batch)
+
     def read_row_group(self, int i, column_indices=None,
                        bint use_threads=True):
         return self.read_row_groups([i], column_indices, use_threads)
diff --git a/python/pyarrow/parquet.py b/python/pyarrow/parquet.py
index afaeb6f..9ad2998 100644
--- a/python/pyarrow/parquet.py
+++ b/python/pyarrow/parquet.py
@@ -319,6 +319,44 @@ class ParquetFile:
                                            column_indices=column_indices,
                                            use_threads=use_threads)
 
+    def iter_batches(self, batch_size=65536, row_groups=None, columns=None,
+                     use_threads=True, use_pandas_metadata=False):
+        """
+        Read streaming batches from a Parquet file
+
+        Parameters
+        ----------
+        batch_size: int, default 64K
+            Maximum number of records to yield per batch. Batches may be
+            smaller if there aren't enough rows in the file.
+        row_groups: list
+            Only these row groups will be read from the file.
+        columns: list
+            If not None, only these columns will be read from the file. A
+            column name may be a prefix of a nested field, e.g. 'a' will select
+            'a.b', 'a.c', and 'a.d.e'.
+        use_threads : boolean, default True
+            Perform multi-threaded column reads.
+        use_pandas_metadata : boolean, default False
+            If True and file has custom pandas schema metadata, ensure that
+            index columns are also loaded.
+
+        Returns
+        -------
+        iterator of pyarrow.RecordBatch
+            Contents of each batch as a record batch
+        """
+        if row_groups is None:
+            row_groups = range(0, self.metadata.num_row_groups)
+        column_indices = self._get_column_indices(
+            columns, use_pandas_metadata=use_pandas_metadata)
+
+        batches = self.reader.iter_batches(batch_size,
+                                           row_groups=row_groups,
+                                           column_indices=column_indices,
+                                           use_threads=use_threads)
+        return batches
+
     def read(self, columns=None, use_threads=True, use_pandas_metadata=False):
         """
         Read a Table from Parquet format,
diff --git a/python/pyarrow/tests/parquet/test_parquet_file.py b/python/pyarrow/tests/parquet/test_parquet_file.py
index d50a77b..85f81a3 100644
--- a/python/pyarrow/tests/parquet/test_parquet_file.py
+++ b/python/pyarrow/tests/parquet/test_parquet_file.py
@@ -186,3 +186,73 @@ def test_read_column_invalid_index():
     for index in (-1, 2):
         with pytest.raises((ValueError, IndexError)):
             f.reader.read_column(index)
+
+
+@pytest.mark.pandas
+@pytest.mark.parametrize('batch_size', [300, 1000, 1300])
+def test_iter_batches_columns_reader(tempdir, batch_size):
+    total_size = 3000
+    chunk_size = 1000
+    # TODO: Add categorical support
+    df = alltypes_sample(size=total_size)
+
+    filename = tempdir / 'pandas_roundtrip.parquet'
+    arrow_table = pa.Table.from_pandas(df)
+    _write_table(arrow_table, filename, version="2.0",
+                 coerce_timestamps='ms', chunk_size=chunk_size)
+
+    file_ = pq.ParquetFile(filename)
+    for columns in [df.columns[:10], df.columns[10:]]:
+        batches = file_.iter_batches(batch_size=batch_size, columns=columns)
+        batch_starts = range(0, total_size+batch_size, batch_size)
+        for batch, start in zip(batches, batch_starts):
+            end = min(total_size, start + batch_size)
+            tm.assert_frame_equal(
+                batch.to_pandas(),
+                df.iloc[start:end, :].loc[:, columns].reset_index(drop=True)
+            )
+
+
+@pytest.mark.pandas
+@pytest.mark.parametrize('chunk_size', [1000])
+def test_iter_batches_reader(tempdir, chunk_size):
+    df = alltypes_sample(size=10000, categorical=True)
+
+    filename = tempdir / 'pandas_roundtrip.parquet'
+    arrow_table = pa.Table.from_pandas(df)
+    assert arrow_table.schema.pandas_metadata is not None
+
+    _write_table(arrow_table, filename, version="2.0",
+                 coerce_timestamps='ms', chunk_size=chunk_size)
+
+    file_ = pq.ParquetFile(filename)
+
+    def get_all_batches(f):
+        for row_group in range(f.num_row_groups):
+            batches = f.iter_batches(
+                batch_size=900,
+                row_groups=[row_group],
+            )
+
+            for batch in batches:
+                yield batch
+
+    batches = list(get_all_batches(file_))
+    batch_no = 0
+
+    for i in range(file_.num_row_groups):
+        tm.assert_frame_equal(
+            batches[batch_no].to_pandas(),
+            file_.read_row_groups([i]).to_pandas().head(900)
+        )
+
+        batch_no += 1
+
+        tm.assert_frame_equal(
+            batches[batch_no].to_pandas().reset_index(drop=True),
+            file_.read_row_groups([i]).to_pandas().iloc[900:].reset_index(
+                drop=True
+            )
+        )
+
+        batch_no += 1

Reply via email to