This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new cc9668756 Cleanup record skipping logic and tests (#2158) (#2199)
cc9668756 is described below

commit cc9668756d37aef8f9d6e6f0484eae67e1e54e11
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Thu Jul 28 10:57:18 2022 +0100

    Cleanup record skipping logic and tests (#2158) (#2199)
---
 parquet/src/arrow/array_reader/byte_array.rs       |   5 +-
 .../arrow/array_reader/byte_array_dictionary.rs    |   5 +-
 parquet/src/arrow/array_reader/mod.rs              |  49 +++--
 parquet/src/arrow/array_reader/null_array.rs       |   5 +-
 parquet/src/arrow/array_reader/primitive_array.rs  |   5 +-
 parquet/src/arrow/arrow_reader.rs                  | 230 ++++++++-------------
 6 files changed, 121 insertions(+), 178 deletions(-)

diff --git a/parquet/src/arrow/array_reader/byte_array.rs b/parquet/src/arrow/array_reader/byte_array.rs
index a29888f70..ec4188890 100644
--- a/parquet/src/arrow/array_reader/byte_array.rs
+++ b/parquet/src/arrow/array_reader/byte_array.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::arrow::array_reader::{read_records, ArrayReader, set_column_reader};
+use crate::arrow::array_reader::{read_records, skip_records, ArrayReader};
 use crate::arrow::buffer::offset_buffer::OffsetBuffer;
 use crate::arrow::record_reader::buffer::ScalarValue;
 use crate::arrow::record_reader::GenericRecordReader;
@@ -120,8 +120,7 @@ impl<I: OffsetSizeTrait + ScalarValue> ArrayReader for ByteArrayReader<I> {
     }
 
     fn skip_records(&mut self, num_records: usize) -> Result<usize> {
-        set_column_reader(&mut self.record_reader, self.pages.as_mut())?;
-        self.record_reader.skip_records(num_records)
+        skip_records(&mut self.record_reader, self.pages.as_mut(), num_records)
     }
 
     fn get_def_levels(&self) -> Option<&[i16]> {
diff --git a/parquet/src/arrow/array_reader/byte_array_dictionary.rs b/parquet/src/arrow/array_reader/byte_array_dictionary.rs
index eba9e578f..51ef38d0d 100644
--- a/parquet/src/arrow/array_reader/byte_array_dictionary.rs
+++ b/parquet/src/arrow/array_reader/byte_array_dictionary.rs
@@ -25,7 +25,7 @@ use arrow::buffer::Buffer;
 use arrow::datatypes::{ArrowNativeType, DataType as ArrowType};
 
 use crate::arrow::array_reader::byte_array::{ByteArrayDecoder, ByteArrayDecoderPlain};
-use crate::arrow::array_reader::{read_records, ArrayReader, set_column_reader};
+use crate::arrow::array_reader::{read_records, ArrayReader, skip_records};
 use crate::arrow::buffer::{
     dictionary_buffer::DictionaryBuffer, offset_buffer::OffsetBuffer,
 };
@@ -181,8 +181,7 @@ where
     }
 
     fn skip_records(&mut self, num_records: usize) -> Result<usize> {
-        set_column_reader(&mut self.record_reader, self.pages.as_mut())?;
-        self.record_reader.skip_records(num_records)
+        skip_records(&mut self.record_reader, self.pages.as_mut(), num_records)
     }
 
     fn get_def_levels(&self) -> Option<&[i16]> {
diff --git a/parquet/src/arrow/array_reader/mod.rs b/parquet/src/arrow/array_reader/mod.rs
index a9d8cc0fa..8bdd6c071 100644
--- a/parquet/src/arrow/array_reader/mod.rs
+++ b/parquet/src/arrow/array_reader/mod.rs
@@ -113,7 +113,7 @@ impl RowGroupCollection for Arc<dyn FileReader> {
 
 /// Uses `record_reader` to read up to `batch_size` records from `pages`
 ///
-/// Returns the number of records read, which can be less than batch_size if
+/// Returns the number of records read, which can be less than `batch_size` if
 /// pages is exhausted.
 fn read_records<V, CV>(
     record_reader: &mut GenericRecordReader<V, CV>,
@@ -145,29 +145,36 @@ where
     Ok(records_read)
 }
 
-/// Uses `pages` to set up to `record_reader` 's `column_reader`
+/// Uses `record_reader` to skip up to `batch_size` records from `pages`
 ///
-/// If we skip records before all read operation,
-/// need set `column_reader` by `set_page_reader`
-/// for constructing `def_level_decoder` and `rep_level_decoder`.
-fn set_column_reader<V, CV>(
+/// Returns the number of records skipped, which can be less than `batch_size` if
+/// pages is exhausted
+fn skip_records<V, CV>(
     record_reader: &mut GenericRecordReader<V, CV>,
     pages: &mut dyn PageIterator,
-) -> Result<bool>
-where
-    V: ValuesBuffer + Default,
-    CV: ColumnValueDecoder<Slice = V::Slice>,
+    batch_size: usize,
+) -> Result<usize>
+    where
+        V: ValuesBuffer + Default,
+        CV: ColumnValueDecoder<Slice = V::Slice>,
 {
-    return if record_reader.column_reader().is_none() {
-        // If we skip records before all read operation
-        // we need set `column_reader` by `set_page_reader`
-        if let Some(page_reader) = pages.next() {
-            record_reader.set_page_reader(page_reader?)?;
-            Ok(true)
-        } else {
-            Ok(false)
+    let mut records_skipped = 0usize;
+    while records_skipped < batch_size {
+        let records_to_read = batch_size - records_skipped;
+
+        let records_skipped_once = record_reader.skip_records(records_to_read)?;
+        records_skipped += records_skipped_once;
+
+        // Record reader exhausted
+        if records_skipped_once < records_to_read {
+            if let Some(page_reader) = pages.next() {
+                // Read from new page reader (i.e. column chunk)
+                record_reader.set_page_reader(page_reader?)?;
+            } else {
+                // Page reader also exhausted
+                break;
+            }
         }
-    } else {
-        Ok(true)
-    };
+    }
+    Ok(records_skipped)
 }
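
The loop above is the heart of the cleanup: a single `skip_records` call can now
span column-chunk boundaries, pulling the next page reader from `pages` whenever
the current record reader runs dry. A minimal sketch of the observable behaviour,
assuming a column made of two chunks of 100 records each (the `record_reader` and
`pages` values here are illustrative, not taken from the crate's tests):

    // Hypothetical setup: `pages` yields two column chunks of 100 records each.
    let skipped = skip_records(&mut record_reader, pages.as_mut(), 150)?;
    assert_eq!(skipped, 150); // crosses the chunk boundary transparently

    let skipped = skip_records(&mut record_reader, pages.as_mut(), 100)?;
    assert_eq!(skipped, 50); // pages exhausted: only 50 records remained
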
diff --git a/parquet/src/arrow/array_reader/null_array.rs b/parquet/src/arrow/array_reader/null_array.rs
index a8c50b87f..63f73d41e 100644
--- a/parquet/src/arrow/array_reader/null_array.rs
+++ b/parquet/src/arrow/array_reader/null_array.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::arrow::array_reader::{read_records, ArrayReader, set_column_reader};
+use crate::arrow::array_reader::{read_records, ArrayReader, skip_records};
 use crate::arrow::record_reader::buffer::ScalarValue;
 use crate::arrow::record_reader::RecordReader;
 use crate::column::page::PageIterator;
@@ -97,8 +97,7 @@ where
     }
 
     fn skip_records(&mut self, num_records: usize) -> Result<usize> {
-        set_column_reader(&mut self.record_reader, self.pages.as_mut())?;
-        self.record_reader.skip_records(num_records)
+        skip_records(&mut self.record_reader, self.pages.as_mut(), num_records)
     }
 
     fn get_def_levels(&self) -> Option<&[i16]> {
diff --git a/parquet/src/arrow/array_reader/primitive_array.rs b/parquet/src/arrow/array_reader/primitive_array.rs
index 700b12b0a..2a59f0326 100644
--- a/parquet/src/arrow/array_reader/primitive_array.rs
+++ b/parquet/src/arrow/array_reader/primitive_array.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::arrow::array_reader::{read_records, set_column_reader, ArrayReader};
+use crate::arrow::array_reader::{read_records, skip_records, ArrayReader};
 use crate::arrow::record_reader::buffer::ScalarValue;
 use crate::arrow::record_reader::RecordReader;
 use crate::arrow::schema::parquet_to_arrow_field;
@@ -222,8 +222,7 @@ where
     }
 
     fn skip_records(&mut self, num_records: usize) -> Result<usize> {
-        set_column_reader(&mut self.record_reader, self.pages.as_mut())?;
-        self.record_reader.skip_records(num_records)
+        skip_records(&mut self.record_reader, self.pages.as_mut(), num_records)
     }
 
     fn get_def_levels(&self) -> Option<&[i16]> {
diff --git a/parquet/src/arrow/arrow_reader.rs b/parquet/src/arrow/arrow_reader.rs
index b64d1bfbf..19985818d 100644
--- a/parquet/src/arrow/arrow_reader.rs
+++ b/parquet/src/arrow/arrow_reader.rs
@@ -384,6 +384,7 @@ impl ParquetRecordBatchReader {
 mod tests {
     use bytes::Bytes;
     use std::cmp::min;
+    use std::collections::VecDeque;
     use std::convert::TryFrom;
     use std::fs::File;
     use std::io::Seek;
@@ -1624,154 +1625,105 @@ mod tests {
         test_row_group_batch(MIN_BATCH_SIZE - 1, MIN_BATCH_SIZE);
     }
 
-    #[test]
-    fn test_scan_row_with_selection() {
-        let testdata = arrow::util::test_util::parquet_test_data();
-        let path = format!("{}/alltypes_tiny_pages_plain.parquet", testdata);
-        let test_file = File::open(&path).unwrap();
+    /// Given a RecordBatch containing all the column data, return the expected batches given
+    /// a `batch_size` and `selection`
+    fn get_expected_batches(
+        column: &RecordBatch,
+        selection: &[RowSelection],
+        batch_size: usize,
+    ) -> Vec<RecordBatch> {
+        let mut expected_batches = vec![];
+
+        let mut selection: VecDeque<_> = selection.iter().cloned().collect();
+        let mut row_offset = 0;
+        let mut last_start = None;
+        while row_offset < column.num_rows() && !selection.is_empty() {
+            let mut batch_remaining = batch_size.min(column.num_rows() - row_offset);
+            while batch_remaining > 0 && !selection.is_empty() {
+                let (to_read, skip) = match selection.front_mut() {
+                    Some(selection) if selection.row_count > batch_remaining => {
+                        selection.row_count -= batch_remaining;
+                        (batch_remaining, selection.skip)
+                    }
+                    Some(_) => {
+                        let select = selection.pop_front().unwrap();
+                        (select.row_count, select.skip)
+                    }
+                    None => break,
+                };
 
-        // total row count 7300
-        // 1. test selection len more than one page row count
-        let batch_size = 1000;
-        let expected_data = create_expect_batch(&test_file, batch_size);
-
-        let selections = create_test_selection(batch_size, 7300, false);
-        let skip_reader = create_skip_reader(&test_file, batch_size, selections);
-        let mut total_row_count = 0;
-        let mut index = 0;
-        for batch in skip_reader {
-            let batch = batch.unwrap();
-            assert_eq!(batch, expected_data.get(index).unwrap().clone());
-            index += 2;
-            let num = batch.num_rows();
-            assert!(num == batch_size || num == 300);
-            total_row_count += num;
-        }
-        assert_eq!(total_row_count, 4000);
+                batch_remaining -= to_read;
 
-        let selections = create_test_selection(batch_size, 7300, true);
-        let skip_reader = create_skip_reader(&test_file, batch_size, selections);
-        let mut total_row_count = 0;
-        let mut index = 1;
-        for batch in skip_reader {
-            let batch = batch.unwrap();
-            assert_eq!(batch, expected_data.get(index).unwrap().clone());
-            index += 2;
-            let num = batch.num_rows();
-            //the lase batch will be 300
-            assert!(num == batch_size || num == 300);
-            total_row_count += num;
+                match skip {
+                    true => {
+                        if let Some(last_start) = last_start.take() {
+                            expected_batches
+                                .push(column.slice(last_start, row_offset - last_start))
+                        }
+                        row_offset += to_read
+                    }
+                    false => {
+                        last_start.get_or_insert(row_offset);
+                        row_offset += to_read
+                    }
+                }
+            }
         }
-        assert_eq!(total_row_count, 3300);
 
-        // 2. test selection len less than one page row count
-        let batch_size = 20;
-        let expected_data = create_expect_batch(&test_file, batch_size);
-        let selections = create_test_selection(batch_size, 7300, false);
-
-        let skip_reader = create_skip_reader(&test_file, batch_size, selections);
-        let mut total_row_count = 0;
-        let mut index = 0;
-        for batch in skip_reader {
-            let batch = batch.unwrap();
-            assert_eq!(batch, expected_data.get(index).unwrap().clone());
-            index += 2;
-            let num = batch.num_rows();
-            assert_eq!(num, batch_size);
-            total_row_count += num;
+        if let Some(last_start) = last_start.take() {
+            expected_batches.push(column.slice(last_start, row_offset - last_start))
         }
-        assert_eq!(total_row_count, 3660);
 
-        let selections = create_test_selection(batch_size, 7300, true);
-        let skip_reader = create_skip_reader(&test_file, batch_size, selections);
-        let mut total_row_count = 0;
-        let mut index = 1;
-        for batch in skip_reader {
-            let batch = batch.unwrap();
-            assert_eq!(batch, expected_data.get(index).unwrap().clone());
-            index += 2;
-            let num = batch.num_rows();
-            assert_eq!(num, batch_size);
-            total_row_count += num;
+        // Sanity check, all batches except the final should be the batch size
+        for batch in &expected_batches[..expected_batches.len() - 1] {
+            assert_eq!(batch.num_rows(), batch_size);
         }
-        assert_eq!(total_row_count, 3640);
 
-        // 3. test selection_len less than batch_size
-        let batch_size = 20;
-        let selection_len = 5;
-        let expected_data_batch = create_expect_batch(&test_file, batch_size);
-        let expected_data_selection = create_expect_batch(&test_file, selection_len);
-        let selections = create_test_selection(selection_len, 7300, false);
-        let skip_reader = create_skip_reader(&test_file, batch_size, selections);
+        expected_batches
+    }
+
+    #[test]
+    fn test_scan_row_with_selection() {
+        let testdata = arrow::util::test_util::parquet_test_data();
+        let path = format!("{}/alltypes_tiny_pages_plain.parquet", testdata);
+        let test_file = File::open(&path).unwrap();
 
-        let mut total_row_count = 0;
+        let mut serial_arrow_reader =
+            ParquetFileArrowReader::try_new(File::open(path).unwrap()).unwrap();
+        let mut serial_reader = serial_arrow_reader.get_record_reader(7300).unwrap();
+        let data = serial_reader.next().unwrap().unwrap();
 
-        for batch in skip_reader {
-            let batch = batch.unwrap();
-            let num = batch.num_rows();
-            assert!(num == batch_size || num == selection_len);
-            if num == batch_size {
-                assert_eq!(
-                    batch,
-                    expected_data_batch
-                        .get(total_row_count / batch_size)
-                        .unwrap()
-                        .clone()
-                );
-                total_row_count += batch_size;
-            } else if num == selection_len {
+        let do_test = |batch_size: usize, selection_len: usize| {
+            for skip_first in [false, true] {
+                let selections =
+                    create_test_selection(selection_len, data.num_rows(), skip_first);
+
+                let expected = get_expected_batches(&data, &selections, batch_size);
+                let skip_reader = create_skip_reader(&test_file, batch_size, selections);
                 assert_eq!(
-                    batch,
-                    expected_data_selection
-                        .get(total_row_count / selection_len)
-                        .unwrap()
-                        .clone()
+                    skip_reader.collect::<ArrowResult<Vec<_>>>().unwrap(),
+                    expected,
+                    "batch_size: {}, selection_len: {}, skip_first: {}",
+                    batch_size,
+                    selection_len,
+                    skip_first
                 );
-                total_row_count += selection_len;
             }
-            // add skip offset
-            total_row_count += selection_len;
-        }
+        };
+
+        // total row count 7300
+        // 1. test selection len more than one page row count
+        do_test(1000, 1000);
+
+        // 2. test selection len less than one page row count
+        do_test(20, 20);
+
+        // 3. test selection_len less than batch_size
+        do_test(20, 5);
 
         // 4. test selection_len more than batch_size
-        // If batch_size < selection_len will divide selection(50, read) ->
-        // selection(20, read), selection(20, read), selection(10, read)
-        let batch_size = 20;
-        let selection_len = 50;
-        let another_batch_size = 10;
-        let expected_data_batch = create_expect_batch(&test_file, batch_size);
-        let expected_data_batch2 = create_expect_batch(&test_file, another_batch_size);
-        let selections = create_test_selection(selection_len, 7300, false);
-        let skip_reader = create_skip_reader(&test_file, batch_size, selections);
-
-        let mut total_row_count = 0;
-
-        for batch in skip_reader {
-            let batch = batch.unwrap();
-            let num = batch.num_rows();
-            assert!(num == batch_size || num == another_batch_size);
-            if num == batch_size {
-                assert_eq!(
-                    batch,
-                    expected_data_batch
-                        .get(total_row_count / batch_size)
-                        .unwrap()
-                        .clone()
-                );
-                total_row_count += batch_size;
-            } else if num == another_batch_size {
-                assert_eq!(
-                    batch,
-                    expected_data_batch2
-                        .get(total_row_count / another_batch_size)
-                        .unwrap()
-                        .clone()
-                );
-                total_row_count += 10;
-                // add skip offset
-                total_row_count += selection_len;
-            }
-        }
+        // If batch_size < selection_len
+        do_test(20, 50);
 
         fn create_skip_reader(
             test_file: &File,
@@ -1812,17 +1764,5 @@ mod tests {
             }
             vec
         }
-
-        fn create_expect_batch(test_file: &File, batch_size: usize) -> Vec<RecordBatch> {
-            let mut serial_arrow_reader =
-                ParquetFileArrowReader::try_new(test_file.try_clone().unwrap()).unwrap();
-            let serial_reader =
-                serial_arrow_reader.get_record_reader(batch_size).unwrap();
-            let mut expected_data = vec![];
-            for batch in serial_reader {
-                expected_data.push(batch.unwrap());
-            }
-            expected_data
-        }
     }
 }
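
For reference, the rewritten test drives every case through `do_test`, comparing the
skip reader's batches against `get_expected_batches`. A hedged sketch of the kind of
selection these helpers consume, assuming a `RowSelection` with the `row_count` and
`skip` fields used above (the literal values are illustrative, not the exact output of
`create_test_selection`):

    // Alternating read/skip runs over the 7300-row file, e.g. with a step of 20:
    let selection = vec![
        RowSelection { row_count: 20, skip: false }, // emit rows 0..20
        RowSelection { row_count: 20, skip: true },  // skip rows 20..40
        RowSelection { row_count: 20, skip: false }, // emit rows 40..60
        // ... and so on until all rows are covered
    ];

`get_expected_batches` then slices the fully materialized `RecordBatch` at the runs
marked `skip: false` and re-chunks them to `batch_size`, which is exactly what the
reader under test is expected to produce.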
