This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new ca1d85f74 Add explicit column mask construction in parquet: `ProjectionMask` (#1701) (#1716)
ca1d85f74 is described below

commit ca1d85f746099c79b43700496042ed567d95c6cc
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Tue May 24 08:08:18 2022 +0100

    Add explicit column mask construction in parquet: `ProjectionMask` (#1701) (#1716)
    
    * Add explicit column mask construction (#1701)
    
    * Fix ParquetRecordBatchStream
    
    * Fix docs
    
    * Fix async_reader test
    
    * Review feedback
---
 parquet/src/arrow/array_reader/builder.rs    |  22 ++---
 parquet/src/arrow/array_reader/list_array.rs |   9 +-
 parquet/src/arrow/arrow_reader.rs            | 111 +++++++++------------
 parquet/src/arrow/async_reader.rs            |  54 +++++------
 parquet/src/arrow/mod.rs                     | 104 +++++++++++++++++++-
 parquet/src/arrow/schema.rs                  | 138 ++++++---------------------
 parquet/src/arrow/schema/complex.rs          |  27 ++----
 parquet/src/schema/types.rs                  |  37 +++----
 8 files changed, 237 insertions(+), 265 deletions(-)
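
The net effect for callers, as a minimal sketch of the new synchronous API
(the "data.parquet" path and the column indices below are illustrative, not
part of the commit):

```rust
use std::fs::File;
use std::sync::Arc;

use parquet::arrow::{ArrowReader, ParquetFileArrowReader, ProjectionMask};
use parquet::file::reader::{FileReader, SerializedFileReader};

fn main() {
    let file = File::open("data.parquet").unwrap();
    let file_reader = Arc::new(SerializedFileReader::new(file).unwrap());

    // Columns are now selected via an explicit mask built against the file's
    // own schema, rather than by passing bare `usize` indices around
    let file_metadata = file_reader.metadata().file_metadata();
    let mask = ProjectionMask::leaves(file_metadata.schema_descr(), [0, 1]);

    let mut arrow_reader = ParquetFileArrowReader::new(file_reader);
    let record_batch_reader = arrow_reader
        .get_record_reader_by_columns(mask, 1024)
        .unwrap();

    for maybe_batch in record_batch_reader {
        println!("read {} rows", maybe_batch.unwrap().num_rows());
    }
}
```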

diff --git a/parquet/src/arrow/array_reader/builder.rs b/parquet/src/arrow/array_reader/builder.rs
index 496af52ed..7b9adfc23 100644
--- a/parquet/src/arrow/array_reader/builder.rs
+++ b/parquet/src/arrow/array_reader/builder.rs
@@ -32,6 +32,7 @@ use crate::arrow::converter::{
     IntervalYearMonthArrayConverter, IntervalYearMonthConverter,
 };
 use crate::arrow::schema::{convert_schema, ParquetField, ParquetFieldType};
+use crate::arrow::ProjectionMask;
 use crate::basic::Type as PhysicalType;
 use crate::data_type::{
     BoolType, DoubleType, FixedLenByteArrayType, FloatType, Int32Type, Int64Type,
@@ -40,21 +41,15 @@ use crate::data_type::{
 use crate::errors::Result;
 use crate::schema::types::{ColumnDescriptor, ColumnPath, SchemaDescPtr, Type};
 
-/// Create array reader from parquet schema, column indices, and parquet file reader.
-pub fn build_array_reader<T>(
+/// Create array reader from parquet schema, projection mask, and parquet file reader.
+pub fn build_array_reader(
     parquet_schema: SchemaDescPtr,
     arrow_schema: SchemaRef,
-    column_indices: T,
+    mask: ProjectionMask,
     row_groups: Box<dyn RowGroupCollection>,
-) -> Result<Box<dyn ArrayReader>>
-where
-    T: IntoIterator<Item = usize>,
-{
-    let field = convert_schema(
-        parquet_schema.as_ref(),
-        column_indices,
-        Some(arrow_schema.as_ref()),
-    )?;
+) -> Result<Box<dyn ArrayReader>> {
+    let field =
+        convert_schema(parquet_schema.as_ref(), mask, Some(arrow_schema.as_ref()))?;
 
     match &field {
         Some(field) => build_reader(field, row_groups.as_ref()),
@@ -346,6 +341,7 @@ mod tests {
             Arc::new(SerializedFileReader::new(file).unwrap());
 
         let file_metadata = file_reader.metadata().file_metadata();
+        let mask = ProjectionMask::leaves(file_metadata.schema_descr(), [0]);
         let arrow_schema = parquet_to_arrow_schema(
             file_metadata.schema_descr(),
             file_metadata.key_value_metadata(),
@@ -355,7 +351,7 @@ mod tests {
         let array_reader = build_array_reader(
             file_reader.metadata().file_metadata().schema_descr_ptr(),
             Arc::new(arrow_schema),
-            vec![0usize].into_iter(),
+            mask,
             Box::new(file_reader),
         )
         .unwrap();
diff --git a/parquet/src/arrow/array_reader/list_array.rs b/parquet/src/arrow/array_reader/list_array.rs
index 31796d79b..2d199f69e 100644
--- a/parquet/src/arrow/array_reader/list_array.rs
+++ b/parquet/src/arrow/array_reader/list_array.rs
@@ -246,7 +246,7 @@ mod tests {
     use crate::arrow::array_reader::build_array_reader;
     use crate::arrow::array_reader::list_array::ListArrayReader;
     use crate::arrow::array_reader::test_util::InMemoryArrayReader;
-    use crate::arrow::{parquet_to_arrow_schema, ArrowWriter};
+    use crate::arrow::{parquet_to_arrow_schema, ArrowWriter, ProjectionMask};
     use crate::file::properties::WriterProperties;
     use crate::file::reader::{FileReader, SerializedFileReader};
     use crate::schema::parser::parse_message_type;
@@ -582,10 +582,13 @@ mod tests {
         )
         .unwrap();
 
+        let schema = file_metadata.schema_descr_ptr();
+        let mask = ProjectionMask::leaves(&schema, vec![0]);
+
         let mut array_reader = build_array_reader(
-            file_reader.metadata().file_metadata().schema_descr_ptr(),
+            schema,
             Arc::new(arrow_schema),
-            vec![0usize].into_iter(),
+            mask,
             Box::new(file_reader),
         )
         .unwrap();
diff --git a/parquet/src/arrow/arrow_reader.rs b/parquet/src/arrow/arrow_reader.rs
index 0c4ded90b..e3a1d1233 100644
--- a/parquet/src/arrow/arrow_reader.rs
+++ b/parquet/src/arrow/arrow_reader.rs
@@ -27,9 +27,8 @@ use arrow::{array::StructArray, error::ArrowError};
 
 use crate::arrow::array_reader::{build_array_reader, ArrayReader};
 use crate::arrow::schema::parquet_to_arrow_schema;
-use crate::arrow::schema::{
-    parquet_to_arrow_schema_by_columns, parquet_to_arrow_schema_by_root_columns,
-};
+use crate::arrow::schema::parquet_to_arrow_schema_by_columns;
+use crate::arrow::ProjectionMask;
 use crate::errors::Result;
 use crate::file::metadata::{KeyValue, ParquetMetaData};
 use crate::file::reader::FileReader;
@@ -44,15 +43,8 @@ pub trait ArrowReader {
     fn get_schema(&mut self) -> Result<Schema>;
 
     /// Read parquet schema and convert it into arrow schema.
-    /// This schema only includes columns identified by `column_indices`.
-    /// To select leaf columns (i.e. `a.b.c` instead of `a`), set `leaf_columns = true`
-    fn get_schema_by_columns<T>(
-        &mut self,
-        column_indices: T,
-        leaf_columns: bool,
-    ) -> Result<Schema>
-    where
-        T: IntoIterator<Item = usize>;
+    /// This schema only includes columns identified by `mask`.
+    fn get_schema_by_columns(&mut self, mask: ProjectionMask) -> Result<Schema>;
 
     /// Returns record batch reader from whole parquet file.
     ///
@@ -64,19 +56,17 @@ pub trait ArrowReader {
     fn get_record_reader(&mut self, batch_size: usize) -> Result<Self::RecordReader>;
 
     /// Returns record batch reader whose record batch contains columns identified by
-    /// `column_indices`.
+    /// `mask`.
     ///
     /// # Arguments
     ///
-    /// `column_indices`: The columns that should be included in record batches.
+    /// `mask`: The columns that should be included in record batches.
     /// `batch_size`: Please refer to `get_record_reader`.
-    fn get_record_reader_by_columns<T>(
+    fn get_record_reader_by_columns(
         &mut self,
-        column_indices: T,
+        mask: ProjectionMask,
         batch_size: usize,
-    ) -> Result<Self::RecordReader>
-    where
-        T: IntoIterator<Item = usize>;
+    ) -> Result<Self::RecordReader>;
 }
 
 #[derive(Debug, Clone, Default)]
@@ -118,59 +108,34 @@ impl ArrowReader for ParquetFileArrowReader {
         parquet_to_arrow_schema(file_metadata.schema_descr(), self.get_kv_metadata())
     }
 
-    fn get_schema_by_columns<T>(
-        &mut self,
-        column_indices: T,
-        leaf_columns: bool,
-    ) -> Result<Schema>
-    where
-        T: IntoIterator<Item = usize>,
-    {
+    fn get_schema_by_columns(&mut self, mask: ProjectionMask) -> Result<Schema> {
         let file_metadata = self.file_reader.metadata().file_metadata();
-        if leaf_columns {
-            parquet_to_arrow_schema_by_columns(
-                file_metadata.schema_descr(),
-                column_indices,
-                self.get_kv_metadata(),
-            )
-        } else {
-            parquet_to_arrow_schema_by_root_columns(
-                file_metadata.schema_descr(),
-                column_indices,
-                self.get_kv_metadata(),
-            )
-        }
+        parquet_to_arrow_schema_by_columns(
+            file_metadata.schema_descr(),
+            mask,
+            self.get_kv_metadata(),
+        )
     }
 
     fn get_record_reader(
         &mut self,
         batch_size: usize,
     ) -> Result<ParquetRecordBatchReader> {
-        let column_indices = 0..self
-            .file_reader
-            .metadata()
-            .file_metadata()
-            .schema_descr()
-            .num_columns();
-
-        self.get_record_reader_by_columns(column_indices, batch_size)
+        self.get_record_reader_by_columns(ProjectionMask::all(), batch_size)
     }
 
-    fn get_record_reader_by_columns<T>(
+    fn get_record_reader_by_columns(
         &mut self,
-        column_indices: T,
+        mask: ProjectionMask,
         batch_size: usize,
-    ) -> Result<ParquetRecordBatchReader>
-    where
-        T: IntoIterator<Item = usize>,
-    {
+    ) -> Result<ParquetRecordBatchReader> {
         let array_reader = build_array_reader(
             self.file_reader
                 .metadata()
                 .file_metadata()
                 .schema_descr_ptr(),
             Arc::new(self.get_schema()?),
-            column_indices,
+            mask,
             Box::new(self.file_reader.clone()),
         )?;
 
@@ -296,7 +261,7 @@ mod tests {
         IntervalDayTimeArrayConverter, LargeUtf8ArrayConverter, Utf8ArrayConverter,
     };
     use crate::arrow::schema::add_encoded_arrow_schema_to_metadata;
-    use crate::arrow::ArrowWriter;
+    use crate::arrow::{ArrowWriter, ProjectionMask};
     use crate::basic::{ConvertedType, Encoding, Repetition, Type as PhysicalType};
     use crate::column::writer::get_typed_column_writer_mut;
     use crate::data_type::{
@@ -351,12 +316,14 @@ mod tests {
         let parquet_file_reader =
             get_test_reader("parquet/generated_simple_numerics/blogs.parquet");
 
-        let max_len = parquet_file_reader.metadata().file_metadata().num_rows() as usize;
+        let file_metadata = parquet_file_reader.metadata().file_metadata();
+        let max_len = file_metadata.num_rows() as usize;
 
+        let mask = ProjectionMask::leaves(file_metadata.schema_descr(), [2]);
         let mut arrow_reader = ParquetFileArrowReader::new(parquet_file_reader);
 
         let mut record_batch_reader = arrow_reader
-            .get_record_reader_by_columns(vec![2], 60)
+            .get_record_reader_by_columns(mask, 60)
             .expect("Failed to read into array!");
 
         // Verify that the schema was correctly parsed
@@ -1040,8 +1007,11 @@ mod tests {
         // (see: ARROW-11452)
         let testdata = arrow::util::test_util::parquet_test_data();
         let path = format!("{}/nested_structs.rust.parquet", testdata);
-        let parquet_file_reader =
-            SerializedFileReader::try_from(File::open(&path).unwrap()).unwrap();
+        let file = File::open(&path).unwrap();
+        let parquet_file_reader = SerializedFileReader::try_from(file).unwrap();
+        let file_metadata = parquet_file_reader.metadata().file_metadata();
+        let schema = file_metadata.schema_descr_ptr();
+
         let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(parquet_file_reader));
         let record_batch_reader = arrow_reader
             .get_record_reader(60)
@@ -1051,12 +1021,11 @@ mod tests {
             batch.unwrap();
         }
 
+        let mask = ProjectionMask::leaves(&schema, [3, 8, 10]);
         let projected_reader = arrow_reader
-            .get_record_reader_by_columns(vec![3, 8, 10], 60)
-            .unwrap();
-        let projected_schema = arrow_reader
-            .get_schema_by_columns(vec![3, 8, 10], true)
+            .get_record_reader_by_columns(mask.clone(), 60)
             .unwrap();
+        let projected_schema = arrow_reader.get_schema_by_columns(mask).unwrap();
 
         let expected_schema = Schema::new(vec![
             Field::new(
@@ -1139,8 +1108,11 @@ mod tests {
         }
 
         let file_reader = Arc::new(SerializedFileReader::new(file).unwrap());
+        let file_metadata = file_reader.metadata().file_metadata();
+        let mask = ProjectionMask::leaves(file_metadata.schema_descr(), [0]);
+
         let mut batch = ParquetFileArrowReader::new(file_reader);
-        let reader = batch.get_record_reader_by_columns(vec![0], 1024).unwrap();
+        let reader = batch.get_record_reader_by_columns(mask, 1024).unwrap();
 
         let expected_schema = arrow::datatypes::Schema::new(vec![Field::new(
             "group",
@@ -1178,7 +1150,7 @@ mod tests {
         let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(file_reader));
 
         let mut record_batch_reader = arrow_reader
-            .get_record_reader_by_columns(vec![0], 10)
+            .get_record_reader_by_columns(ProjectionMask::all(), 10)
             .unwrap();
 
         let error = record_batch_reader.next().unwrap().unwrap_err();
@@ -1414,10 +1386,13 @@ mod tests {
         let path = format!("{}/alltypes_plain.parquet", testdata);
         let file = File::open(&path).unwrap();
         let reader = SerializedFileReader::try_from(file).unwrap();
-        let expected_rows = reader.metadata().file_metadata().num_rows() as usize;
+        let file_metadata = reader.metadata().file_metadata();
+        let expected_rows = file_metadata.num_rows() as usize;
+        let schema = file_metadata.schema_descr_ptr();
 
         let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(reader));
-        let batch_reader = arrow_reader.get_record_reader_by_columns([], 2).unwrap();
+        let mask = ProjectionMask::leaves(&schema, []);
+        let batch_reader = arrow_reader.get_record_reader_by_columns(mask, 2).unwrap();
 
         let mut total_rows = 0;
         for maybe_batch in batch_reader {
diff --git a/parquet/src/arrow/async_reader.rs b/parquet/src/arrow/async_reader.rs
index 7bf3ebfa9..5cd091184 100644
--- a/parquet/src/arrow/async_reader.rs
+++ b/parquet/src/arrow/async_reader.rs
@@ -27,7 +27,7 @@
 //! use futures::TryStreamExt;
 //! use tokio::fs::File;
 //!
-//! use parquet::arrow::ParquetRecordBatchStreamBuilder;
+//! use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask};
 //!
 //! # fn assert_batches_eq(batches: &[RecordBatch], expected_lines: &[&str]) {
 //! #     let formatted = pretty_format_batches(batches).unwrap().to_string();
@@ -41,16 +41,17 @@
 //!
 //! let testdata = arrow::util::test_util::parquet_test_data();
 //! let path = format!("{}/alltypes_plain.parquet", testdata);
-//! let file = tokio::fs::File::open(path).await.unwrap();
+//! let file = File::open(path).await.unwrap();
 //!
 //! let builder = ParquetRecordBatchStreamBuilder::new(file)
 //!     .await
 //!     .unwrap()
-//!     .with_projection(vec![1, 2, 6])
 //!     .with_batch_size(3);
 //!
-//! let stream = builder.build().unwrap();
+//! let file_metadata = builder.metadata().file_metadata();
+//! let mask = ProjectionMask::roots(file_metadata.schema_descr(), [1, 2, 6]);
 //!
+//! let stream = builder.with_projection(mask).build().unwrap();
 //! let results = stream.try_collect::<Vec<_>>().await.unwrap();
 //! assert_eq!(results.len(), 3);
 //!
@@ -92,6 +93,7 @@ use arrow::record_batch::RecordBatch;
 use crate::arrow::array_reader::{build_array_reader, RowGroupCollection};
 use crate::arrow::arrow_reader::ParquetRecordBatchReader;
 use crate::arrow::schema::parquet_to_arrow_schema;
+use crate::arrow::ProjectionMask;
 use crate::basic::Compression;
 use crate::column::page::{PageIterator, PageReader};
 use crate::errors::{ParquetError, Result};
@@ -119,7 +121,7 @@ pub struct ParquetRecordBatchStreamBuilder<T> {
 
     row_groups: Option<Vec<usize>>,
 
-    projection: Option<Vec<usize>>,
+    projection: ProjectionMask,
 }
 
 impl<T: AsyncRead + AsyncSeek + Unpin> ParquetRecordBatchStreamBuilder<T> {
@@ -138,7 +140,7 @@ impl<T: AsyncRead + AsyncSeek + Unpin> ParquetRecordBatchStreamBuilder<T> {
             schema,
             batch_size: 1024,
             row_groups: None,
-            projection: None,
+            projection: ProjectionMask::all(),
         })
     }
 
@@ -166,32 +168,17 @@ impl<T: AsyncRead + AsyncSeek + Unpin> ParquetRecordBatchStreamBuilder<T> {
     }
 
     /// Only read data from the provided column indexes
-    pub fn with_projection(self, projection: Vec<usize>) -> Self {
+    pub fn with_projection(self, mask: ProjectionMask) -> Self {
         Self {
-            projection: Some(projection),
+            projection: mask,
             ..self
         }
     }
 
     /// Build a new [`ParquetRecordBatchStream`]
     pub fn build(self) -> Result<ParquetRecordBatchStream<T>> {
-        let num_columns = self.schema.fields().len();
         let num_row_groups = self.metadata.row_groups().len();
 
-        let columns = match self.projection {
-            Some(projection) => {
-                if let Some(col) = projection.iter().find(|x| **x >= num_columns) {
-                    return Err(general_err!(
-                        "column projection {} outside bounds of schema 0..{}",
-                        col,
-                        num_columns
-                    ));
-                }
-                projection
-            }
-            None => (0..num_columns).collect::<Vec<_>>(),
-        };
-
         let row_groups = match self.row_groups {
             Some(row_groups) => {
                 if let Some(col) = row_groups.iter().find(|x| **x >= num_row_groups) {
@@ -208,7 +195,7 @@ impl<T: AsyncRead + AsyncSeek + Unpin> ParquetRecordBatchStreamBuilder<T> {
 
         Ok(ParquetRecordBatchStream {
             row_groups,
-            columns: columns.into(),
+            projection: self.projection,
             batch_size: self.batch_size,
             metadata: self.metadata,
             schema: self.schema,
@@ -248,7 +235,7 @@ pub struct ParquetRecordBatchStream<T> {
 
     batch_size: usize,
 
-    columns: Arc<[usize]>,
+    projection: ProjectionMask,
 
     row_groups: VecDeque<usize>,
 
@@ -264,7 +251,7 @@ impl<T> std::fmt::Debug for ParquetRecordBatchStream<T> {
             .field("metadata", &self.metadata)
             .field("schema", &self.schema)
             .field("batch_size", &self.batch_size)
-            .field("columns", &self.columns)
+            .field("projection", &self.projection)
             .field("state", &self.state)
             .finish()
     }
@@ -315,16 +302,19 @@ impl<T: AsyncRead + AsyncSeek + Unpin + Send + 'static> Stream
                         }
                     };
 
-                    let columns = Arc::clone(&self.columns);
-
+                    let projection = self.projection.clone();
                     self.state = StreamState::Reading(
                         async move {
                             let row_group_metadata = metadata.row_group(row_group_idx);
                             let mut column_chunks =
                                 vec![None; row_group_metadata.columns().len()];
 
-                            for column_idx in columns.iter() {
-                                let column = row_group_metadata.column(*column_idx);
+                            for (idx, chunk) in column_chunks.iter_mut().enumerate() {
+                                if !projection.leaf_included(idx) {
+                                    continue;
+                                }
+
+                                let column = row_group_metadata.column(idx);
                                 let (start, length) = column.byte_range();
                                 let end = start + length;
 
@@ -333,7 +323,7 @@ impl<T: AsyncRead + AsyncSeek + Unpin + Send + 'static> Stream
                                 let mut buffer = vec![0_u8; (end - start) as usize];
                                 input.read_exact(buffer.as_mut_slice()).await?;
 
-                                column_chunks[*column_idx] = Some(InMemoryColumnChunk {
+                                *chunk = Some(InMemoryColumnChunk {
                                     num_values: column.num_values(),
                                     compression: column.compression(),
                                     physical_type: column.column_type(),
@@ -373,7 +363,7 @@ impl<T: AsyncRead + AsyncSeek + Unpin + Send + 'static> Stream
                     let array_reader = build_array_reader(
                         parquet_schema,
                         self.schema.clone(),
-                        self.columns.iter().cloned(),
+                        self.projection.clone(),
                         row_group,
                     )?;
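
Assembled from the doc-comment hunks above, a self-contained sketch of the
async flow (assumes the parquet test data fixtures and a tokio runtime, as in
the module docs):

```rust
use futures::TryStreamExt;
use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask};

#[tokio::main]
async fn main() {
    let testdata = arrow::util::test_util::parquet_test_data();
    let path = format!("{}/alltypes_plain.parquet", testdata);
    let file = tokio::fs::File::open(path).await.unwrap();

    // `new` reads the footer, so the schema is available before building
    let builder = ParquetRecordBatchStreamBuilder::new(file)
        .await
        .unwrap()
        .with_batch_size(3);

    let file_metadata = builder.metadata().file_metadata();
    let mask = ProjectionMask::roots(file_metadata.schema_descr(), [1, 2, 6]);

    let stream = builder.with_projection(mask).build().unwrap();
    let results = stream.try_collect::<Vec<_>>().await.unwrap();
    println!("read {} batches", results.len());
}
```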
 
diff --git a/parquet/src/arrow/mod.rs b/parquet/src/arrow/mod.rs
index ea4760c03..5a5135cd3 100644
--- a/parquet/src/arrow/mod.rs
+++ b/parquet/src/arrow/mod.rs
@@ -67,8 +67,8 @@
 //!
 //! ```rust
 //! use arrow::record_batch::RecordBatchReader;
-//! use parquet::file::reader::SerializedFileReader;
-//! use parquet::arrow::{ParquetFileArrowReader, ArrowReader};
+//! use parquet::file::reader::{FileReader, SerializedFileReader};
+//! use parquet::arrow::{ParquetFileArrowReader, ArrowReader, ProjectionMask};
 //! use std::sync::Arc;
 //! use std::fs::File;
 //!
@@ -97,13 +97,20 @@
 //!
 //! let file = File::open("data.parquet").unwrap();
 //! let file_reader = SerializedFileReader::new(file).unwrap();
+//!
+//! let file_metadata = file_reader.metadata().file_metadata();
+//! let mask = ProjectionMask::leaves(file_metadata.schema_descr(), [0]);
+//!
 //! let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(file_reader));
 //!
 //! println!("Converted arrow schema is: {}", 
arrow_reader.get_schema().unwrap());
 //! println!("Arrow schema after projection is: {}",
-//!    arrow_reader.get_schema_by_columns(vec![0], true).unwrap());
+//! arrow_reader.get_schema_by_columns(mask.clone()).unwrap());
+//!
+//! let mut unprojected = arrow_reader.get_record_reader(2048).unwrap();
+//! println!("Unprojected reader schema: {}", unprojected.schema());
 //!
-//! let mut record_batch_reader = arrow_reader.get_record_reader(2048).unwrap();
+//! let mut record_batch_reader = arrow_reader.get_record_reader_by_columns(mask, 2048).unwrap();
 //!
 //! for maybe_record_batch in record_batch_reader {
 //!    let record_batch = maybe_record_batch.unwrap();
@@ -133,11 +140,98 @@ pub use self::arrow_reader::ParquetFileArrowReader;
 pub use self::arrow_writer::ArrowWriter;
 #[cfg(feature = "async")]
 pub use self::async_reader::ParquetRecordBatchStreamBuilder;
+use crate::schema::types::SchemaDescriptor;
 
 pub use self::schema::{
     arrow_to_parquet_schema, parquet_to_arrow_schema, parquet_to_arrow_schema_by_columns,
-    parquet_to_arrow_schema_by_root_columns,
 };
 
 /// Schema metadata key used to store serialized Arrow IPC schema
 pub const ARROW_SCHEMA_META_KEY: &str = "ARROW:schema";
+
+/// A [`ProjectionMask`] identifies a set of columns within a potentially nested schema to project
+///
+/// In particular, a [`ProjectionMask`] can be constructed from a list of leaf column indices
+/// or root column indices where:
+///
+/// * Root columns are the direct children of the root schema, enumerated in order
+/// * Leaf columns are the child-less leaves of the schema as enumerated by a depth-first search
+///
+/// For example, the schema
+///
+/// ```ignore
+/// message schema {
+///   REQUIRED boolean         leaf_1;
+///   REQUIRED GROUP group {
+///     OPTIONAL int32 leaf_2;
+///     OPTIONAL int64 leaf_3;
+///   }
+/// }
+/// ```
+///
+/// Has roots `["leaf_1", "group"]` and leaves `["leaf_1", "leaf_2", "leaf_3"]`
+///
+/// For non-nested schemas, i.e. those containing only primitive columns, the root
+/// and leaves are the same
+///
+#[derive(Debug, Clone)]
+pub struct ProjectionMask {
+    /// If present a leaf column should be included if the value at
+    /// the corresponding index is true
+    ///
+    /// If `None`, include all columns
+    mask: Option<Vec<bool>>,
+}
+
+impl ProjectionMask {
+    /// Create a [`ProjectionMask`] which selects all columns
+    pub fn all() -> Self {
+        Self { mask: None }
+    }
+
+    /// Create a [`ProjectionMask`] which selects only the specified leaf columns
+    ///
+    /// Note: repeated or out of order indices will not impact the final mask
+    ///
+    /// i.e. `[0, 1, 2]` will construct the same mask as `[1, 0, 0, 2]`
+    pub fn leaves(
+        schema: &SchemaDescriptor,
+        indices: impl IntoIterator<Item = usize>,
+    ) -> Self {
+        let mut mask = vec![false; schema.num_columns()];
+        for leaf_idx in indices {
+            mask[leaf_idx] = true;
+        }
+        Self { mask: Some(mask) }
+    }
+
+    /// Create a [`ProjectionMask`] which selects only the specified root columns
+    ///
+    /// Note: repeated or out of order indices will not impact the final mask
+    ///
+    /// i.e. `[0, 1, 2]` will construct the same mask as `[1, 0, 0, 2]`
+    pub fn roots(
+        schema: &SchemaDescriptor,
+        indices: impl IntoIterator<Item = usize>,
+    ) -> Self {
+        let num_root_columns = schema.root_schema().get_fields().len();
+        let mut root_mask = vec![false; num_root_columns];
+        for root_idx in indices {
+            root_mask[root_idx] = true;
+        }
+
+        let mask = (0..schema.num_columns())
+            .map(|leaf_idx| {
+                let root_idx = schema.get_column_root_idx(leaf_idx);
+                root_mask[root_idx]
+            })
+            .collect();
+
+        Self { mask: Some(mask) }
+    }
+
+    /// Returns true if the leaf column `leaf_idx` is included by the mask
+    pub fn leaf_included(&self, leaf_idx: usize) -> bool {
+        self.mask.as_ref().map(|m| m[leaf_idx]).unwrap_or(true)
+    }
+}
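
For the nested schema in the new doc comment, a sketch of how the two
constructors differ (hypothetical snippet, not part of the commit):

```rust
use std::sync::Arc;

use parquet::arrow::ProjectionMask;
use parquet::schema::parser::parse_message_type;
use parquet::schema::types::SchemaDescriptor;

fn main() {
    let message_type = "
    message schema {
      REQUIRED BOOLEAN leaf_1;
      REQUIRED GROUP group {
        OPTIONAL INT32 leaf_2;
        OPTIONAL INT64 leaf_3;
      }
    }
    ";
    let schema =
        SchemaDescriptor::new(Arc::new(parse_message_type(message_type).unwrap()));

    // Root 1 is `group`; selecting it pulls in both of its leaves
    let roots = ProjectionMask::roots(&schema, [1]);
    assert!(!roots.leaf_included(0)); // leaf_1
    assert!(roots.leaf_included(1)); // leaf_2
    assert!(roots.leaf_included(2)); // leaf_3

    // Leaves are addressed individually by their depth-first index
    let leaves = ProjectionMask::leaves(&schema, [0, 2]);
    assert!(leaves.leaf_included(0)); // leaf_1
    assert!(!leaves.leaf_included(1)); // leaf_2
    assert!(leaves.leaf_included(2)); // leaf_3
}
```
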
diff --git a/parquet/src/arrow/schema.rs b/parquet/src/arrow/schema.rs
index 07c50d11c..820aa7e7a 100644
--- a/parquet/src/arrow/schema.rs
+++ b/parquet/src/arrow/schema.rs
@@ -33,13 +33,14 @@ use crate::basic::{
     ConvertedType, LogicalType, Repetition, TimeUnit as ParquetTimeUnit,
     Type as PhysicalType,
 };
-use crate::errors::{ParquetError::ArrowError, Result};
+use crate::errors::{ParquetError, Result};
 use crate::file::{metadata::KeyValue, properties::WriterProperties};
 use crate::schema::types::{ColumnDescriptor, SchemaDescriptor, Type, TypePtr};
 
 mod complex;
 mod primitive;
 
+use crate::arrow::ProjectionMask;
 pub(crate) use complex::{convert_schema, ParquetField, ParquetFieldType};
 
 /// Convert Parquet schema to Arrow schema including optional metadata.
@@ -51,74 +52,18 @@ pub fn parquet_to_arrow_schema(
 ) -> Result<Schema> {
     parquet_to_arrow_schema_by_columns(
         parquet_schema,
-        0..parquet_schema.columns().len(),
+        ProjectionMask::all(),
         key_value_metadata,
     )
 }
 
-/// Convert parquet schema to arrow schema including optional metadata,
-/// only preserving some root columns.
-/// This is useful if we have columns `a.b`, `a.c.e` and `a.d`,
-/// and want `a` with all its child fields
-pub fn parquet_to_arrow_schema_by_root_columns<T>(
-    parquet_schema: &SchemaDescriptor,
-    column_indices: T,
-    key_value_metadata: Option<&Vec<KeyValue>>,
-) -> Result<Schema>
-where
-    T: IntoIterator<Item = usize>,
-{
-    // Reconstruct the index ranges of the parent columns
-    // An Arrow struct gets represented by 1+ columns based on how many child fields the
-    // struct has. This means that getting fields 1 and 2 might return the struct twice,
-    // if field 1 is the struct having say 3 fields, and field 2 is a primitive.
-    //
-    // The below gets the parent columns, and counts the number of child fields in each parent,
-    // such that we would end up with:
-    // - field 1 - columns: [0, 1, 2]
-    // - field 2 - columns: [3]
-    let mut parent_columns = vec![];
-    let mut curr_name = "";
-    let mut prev_name = "";
-    let mut indices = vec![];
-    (0..(parquet_schema.num_columns())).for_each(|i| {
-        let p_type = parquet_schema.get_column_root(i);
-        curr_name = p_type.get_basic_info().name();
-        if prev_name.is_empty() {
-            // first index
-            indices.push(i);
-            prev_name = curr_name;
-        } else if curr_name != prev_name {
-            prev_name = curr_name;
-            parent_columns.push((curr_name.to_string(), indices.clone()));
-            indices = vec![i];
-        } else {
-            indices.push(i);
-        }
-    });
-    // push the last column if indices has values
-    if !indices.is_empty() {
-        parent_columns.push((curr_name.to_string(), indices));
-    }
-
-    // gather the required leaf columns
-    let leaf_columns = column_indices
-        .into_iter()
-        .flat_map(|i| parent_columns[i].1.clone());
-
-    parquet_to_arrow_schema_by_columns(parquet_schema, leaf_columns, key_value_metadata)
-}
-
 /// Convert parquet schema to arrow schema including optional metadata,
 /// only preserving some leaf columns.
-pub fn parquet_to_arrow_schema_by_columns<T>(
+pub fn parquet_to_arrow_schema_by_columns(
     parquet_schema: &SchemaDescriptor,
-    column_indices: T,
+    mask: ProjectionMask,
     key_value_metadata: Option<&Vec<KeyValue>>,
-) -> Result<Schema>
-where
-    T: IntoIterator<Item = usize>,
-{
+) -> Result<Schema> {
     let mut metadata = parse_key_value_metadata(key_value_metadata).unwrap_or_default();
     let maybe_schema = metadata
         .remove(super::ARROW_SCHEMA_META_KEY)
@@ -132,7 +77,7 @@ where
         });
     }
 
-    match convert_schema(parquet_schema, column_indices, maybe_schema.as_ref())? {
+    match convert_schema(parquet_schema, mask, maybe_schema.as_ref())? {
         Some(field) => match field.arrow_type {
             DataType::Struct(fields) => Ok(Schema::new_with_metadata(fields, metadata)),
             _ => unreachable!(),
@@ -155,24 +100,24 @@ fn get_arrow_schema_from_metadata(encoded_meta: &str) -> Result<Schema> {
                 Ok(message) => message
                     .header_as_schema()
                     .map(arrow::ipc::convert::fb_to_schema)
-                    .ok_or(ArrowError("the message is not Arrow 
Schema".to_string())),
+                    .ok_or(arrow_err!("the message is not Arrow Schema")),
                 Err(err) => {
                     // The flatbuffers implementation returns an error on verification error.
-                    Err(ArrowError(format!(
+                    Err(arrow_err!(
                         "Unable to get root as message stored in {}: {:?}",
                         super::ARROW_SCHEMA_META_KEY,
                         err
-                    )))
+                    ))
                 }
             }
         }
         Err(err) => {
             // The C++ implementation returns an error if the schema can't be parsed.
-            Err(ArrowError(format!(
+            Err(arrow_err!(
                 "Unable to decode the encoded schema stored in {}, {:?}",
                 super::ARROW_SCHEMA_META_KEY,
                 err
-            )))
+            ))
         }
     }
 }
@@ -342,7 +287,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result<Type> {
             }))
             .with_repetition(repetition)
             .build(),
-        DataType::Float16 => Err(ArrowError("Float16 arrays not supported".to_string())),
+        DataType::Float16 => Err(arrow_err!("Float16 arrays not supported")),
         DataType::Float32 => Type::primitive_type_builder(name, PhysicalType::FLOAT)
             .with_repetition(repetition)
             .build(),
@@ -411,9 +356,9 @@ fn arrow_to_parquet_type(field: &Field) -> Result<Type> {
             }))
             .with_repetition(repetition)
             .build(),
-        DataType::Duration(_) => Err(ArrowError(
-            "Converting Duration to parquet not supported".to_string(),
-        )),
+        DataType::Duration(_) => {
+            Err(arrow_err!("Converting Duration to parquet not supported",))
+        }
         DataType::Interval(_) => {
             Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY)
                 .with_converted_type(ConvertedType::INTERVAL)
@@ -477,9 +422,9 @@ fn arrow_to_parquet_type(field: &Field) -> Result<Type> {
         }
         DataType::Struct(fields) => {
             if fields.is_empty() {
-                return Err(ArrowError(
-                    "Parquet does not support writing empty 
structs".to_string(),
-                ));
+                return Err(
+                    arrow_err!("Parquet does not support writing empty 
structs",),
+                );
             }
             // recursively convert children to types/nodes
             let fields: Result<Vec<TypePtr>> = fields
@@ -515,8 +460,8 @@ fn arrow_to_parquet_type(field: &Field) -> Result<Type> {
                     .with_repetition(repetition)
                     .build()
             } else {
-                Err(ArrowError(
-                    "DataType::Map should contain a struct field 
child".to_string(),
+                Err(arrow_err!(
+                    "DataType::Map should contain a struct field child",
                 ))
             }
         }
@@ -626,12 +571,9 @@ mod tests {
         ];
         assert_eq!(&arrow_fields, converted_arrow_schema.fields());
 
-        let converted_arrow_schema = parquet_to_arrow_schema_by_columns(
-            &parquet_schema,
-            vec![0usize, 1usize],
-            None,
-        )
-        .unwrap();
+        let converted_arrow_schema =
+            parquet_to_arrow_schema_by_columns(&parquet_schema, ProjectionMask::all(), None)
+                .unwrap();
         assert_eq!(&arrow_fields, converted_arrow_schema.fields());
     }
 
@@ -1119,33 +1061,15 @@ mod tests {
         // required int64 leaf5;
 
         let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type));
+        let mask = ProjectionMask::leaves(&parquet_schema, [3, 0, 4, 4]);
         let converted_arrow_schema =
-            parquet_to_arrow_schema_by_columns(&parquet_schema, vec![0, 3, 4], None)
-                .unwrap();
+            parquet_to_arrow_schema_by_columns(&parquet_schema, mask, None).unwrap();
         let converted_fields = converted_arrow_schema.fields();
 
         assert_eq!(arrow_fields.len(), converted_fields.len());
         for i in 0..arrow_fields.len() {
             assert_eq!(arrow_fields[i], converted_fields[i]);
         }
-
-        let err =
-            parquet_to_arrow_schema_by_columns(&parquet_schema, vec![3, 2, 4], None)
-                .unwrap_err()
-                .to_string();
-
-        assert!(
-            err.contains("out of order projection is not supported"),
-            "{}",
-            err
-        );
-
-        let err =
-            parquet_to_arrow_schema_by_columns(&parquet_schema, vec![3, 3, 4], None)
-                .unwrap_err()
-                .to_string();
-
-        assert!(err.contains("repeated column projection is not supported, 
column 3 appeared multiple times"), "{}", err);
     }
 
     #[test]
@@ -1188,9 +1112,9 @@ mod tests {
         // required int64 leaf5;
 
         let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type));
+        let mask = ProjectionMask::leaves(&parquet_schema, [3, 0, 4]);
         let converted_arrow_schema =
-            parquet_to_arrow_schema_by_columns(&parquet_schema, vec![0, 3, 4], None)
-                .unwrap();
+            parquet_to_arrow_schema_by_columns(&parquet_schema, mask, None).unwrap();
         let converted_fields = converted_arrow_schema.fields();
 
         assert_eq!(arrow_fields.len(), converted_fields.len());
@@ -1681,8 +1605,7 @@ mod tests {
         assert_eq!(schema, read_schema);
 
         // read all fields by columns
-        let partial_read_schema =
-            arrow_reader.get_schema_by_columns(0..(schema.fields().len()), false)?;
+        let partial_read_schema = arrow_reader.get_schema_by_columns(ProjectionMask::all())?;
         assert_eq!(schema, partial_read_schema);
 
         Ok(())
@@ -1751,8 +1674,7 @@ mod tests {
         assert_eq!(schema, read_schema);
 
         // read all fields by columns
-        let partial_read_schema =
-            arrow_reader.get_schema_by_columns(0..(schema.fields().len()), false)?;
+        let partial_read_schema = arrow_reader.get_schema_by_columns(ProjectionMask::all())?;
         assert_eq!(schema, partial_read_schema);
 
         Ok(())
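
For callers of the removed `parquet_to_arrow_schema_by_root_columns`, the
root-to-leaf expansion now lives in `ProjectionMask::roots`; roughly (the
root indices 0 and 2 below are illustrative):

```rust
use parquet::arrow::{parquet_to_arrow_schema_by_columns, ProjectionMask};
use parquet::errors::Result;
use parquet::schema::types::SchemaDescriptor;

// Previously: parquet_to_arrow_schema_by_root_columns(parquet_schema, vec![0, 2], None)
fn project_roots(parquet_schema: &SchemaDescriptor) -> Result<arrow::datatypes::Schema> {
    let mask = ProjectionMask::roots(parquet_schema, [0, 2]);
    parquet_to_arrow_schema_by_columns(parquet_schema, mask, None)
}
```
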
diff --git a/parquet/src/arrow/schema/complex.rs b/parquet/src/arrow/schema/complex.rs
index 31a9d6e82..d63ab5606 100644
--- a/parquet/src/arrow/schema/complex.rs
+++ b/parquet/src/arrow/schema/complex.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 use crate::arrow::schema::primitive::convert_primitive;
+use crate::arrow::ProjectionMask;
 use crate::basic::{ConvertedType, Repetition};
 use crate::errors::ParquetError;
 use crate::errors::Result;
@@ -120,7 +121,7 @@ struct Visitor {
     next_col_idx: usize,
 
     /// Mask of columns to include
-    column_mask: Vec<bool>,
+    mask: ProjectionMask,
 }
 
 impl Visitor {
@@ -132,7 +133,7 @@ impl Visitor {
         let col_idx = self.next_col_idx;
         self.next_col_idx += 1;
 
-        if !self.column_mask[col_idx] {
+        if !self.mask.leaf_included(col_idx) {
             return Ok(None);
         }
 
@@ -549,28 +550,14 @@ fn convert_field(
 /// [`Schema`] embedded in the parquet metadata
 ///
 /// Note: This does not support out of order column projection
-pub fn convert_schema<T: IntoIterator<Item = usize>>(
+pub fn convert_schema(
     schema: &SchemaDescriptor,
-    leaf_columns: T,
+    mask: ProjectionMask,
     embedded_arrow_schema: Option<&Schema>,
 ) -> Result<Option<ParquetField>> {
-    let mut leaf_mask = vec![false; schema.num_columns()];
-    let mut last_idx = 0;
-    for i in leaf_columns {
-        if i < last_idx {
-            return Err(general_err!("out of order projection is not 
supported"));
-        }
-        if leaf_mask[i] {
-            return Err(general_err!("repeated column projection is not 
supported, column {} appeared multiple times", i));
-        }
-
-        last_idx = i;
-        leaf_mask[i] = true;
-    }
-
     let mut visitor = Visitor {
         next_col_idx: 0,
-        column_mask: leaf_mask,
+        mask,
     };
 
     let context = VisitorContext {
@@ -586,7 +573,7 @@ pub fn convert_schema<T: IntoIterator<Item = usize>>(
 pub fn convert_type(parquet_type: &TypePtr) -> Result<ParquetField> {
     let mut visitor = Visitor {
         next_col_idx: 0,
-        column_mask: vec![true],
+        mask: ProjectionMask::all(),
     };
 
     let context = VisitorContext {
diff --git a/parquet/src/schema/types.rs b/parquet/src/schema/types.rs
index 1785d2950..9cef93a69 100644
--- a/parquet/src/schema/types.rs
+++ b/parquet/src/schema/types.rs
@@ -847,13 +847,13 @@ pub struct SchemaDescriptor {
     // `schema` in DFS order.
     leaves: Vec<ColumnDescPtr>,
 
-    // Mapping from a leaf column's index to the root column type that it
+    // Mapping from a leaf column's index to the root column index that it
     // comes from. For instance: the leaf `a.b.c.d` would have a link back to `a`:
     // -- a  <-----+
     // -- -- b     |
     // -- -- -- c  |
     // -- -- -- -- d
-    leaf_to_base: Vec<TypePtr>,
+    leaf_to_base: Vec<usize>,
 }
 
 impl fmt::Debug for SchemaDescriptor {
@@ -871,9 +871,9 @@ impl SchemaDescriptor {
         assert!(tp.is_group(), "SchemaDescriptor should take a GroupType");
         let mut leaves = vec![];
         let mut leaf_to_base = Vec::new();
-        for f in tp.get_fields() {
+        for (root_idx, f) in tp.get_fields().iter().enumerate() {
             let mut path = vec![];
-            build_tree(f, f, 0, 0, &mut leaves, &mut leaf_to_base, &mut path);
+            build_tree(f, root_idx, 0, 0, &mut leaves, &mut leaf_to_base, &mut path);
         }
 
         Self {
@@ -904,30 +904,35 @@ impl SchemaDescriptor {
         self.leaves.len()
     }
 
-    /// Returns column root [`Type`](crate::schema::types::Type) for a field position.
+    /// Returns column root [`Type`](crate::schema::types::Type) for a leaf position.
     pub fn get_column_root(&self, i: usize) -> &Type {
         let result = self.column_root_of(i);
         result.as_ref()
     }
 
-    /// Returns column root [`Type`](crate::schema::types::Type) pointer for a field
+    /// Returns column root [`Type`](crate::schema::types::Type) pointer for a leaf
     /// position.
     pub fn get_column_root_ptr(&self, i: usize) -> TypePtr {
         let result = self.column_root_of(i);
         result.clone()
     }
 
-    fn column_root_of(&self, i: usize) -> &Arc<Type> {
+    /// Returns the index of the root column for a field position
+    pub fn get_column_root_idx(&self, leaf: usize) -> usize {
         assert!(
-            i < self.leaves.len(),
+            leaf < self.leaves.len(),
             "Index out of bound: {} not in [0, {})",
-            i,
+            leaf,
             self.leaves.len()
         );
 
-        self.leaf_to_base
-            .get(i)
-            .unwrap_or_else(|| panic!("Expected a value for index {} but found 
None", i))
+        *self.leaf_to_base.get(leaf).unwrap_or_else(|| {
+            panic!("Expected a value for index {} but found None", leaf)
+        })
+    }
+
+    fn column_root_of(&self, i: usize) -> &TypePtr {
+        &self.schema.get_fields()[self.get_column_root_idx(i)]
     }
 
     /// Returns schema as [`Type`](crate::schema::types::Type).
@@ -947,11 +952,11 @@ impl SchemaDescriptor {
 
 fn build_tree<'a>(
     tp: &'a TypePtr,
-    base_tp: &TypePtr,
+    root_idx: usize,
     mut max_rep_level: i16,
     mut max_def_level: i16,
     leaves: &mut Vec<ColumnDescPtr>,
-    leaf_to_base: &mut Vec<TypePtr>,
+    leaf_to_base: &mut Vec<usize>,
     path_so_far: &mut Vec<&'a str>,
 ) {
     assert!(tp.get_basic_info().has_repetition());
@@ -978,13 +983,13 @@ fn build_tree<'a>(
                 max_rep_level,
                 ColumnPath::new(path),
             )));
-            leaf_to_base.push(base_tp.clone());
+            leaf_to_base.push(root_idx);
         }
         Type::GroupType { ref fields, .. } => {
             for f in fields {
                 build_tree(
                     f,
-                    base_tp,
+                    root_idx,
                     max_rep_level,
                     max_def_level,
                     leaves,
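
Using the `a.b.c.d` shape from the comment above, a sketch of the new
index-based root lookup (hypothetical schema, not from the test suite):

```rust
use std::sync::Arc;

use parquet::schema::parser::parse_message_type;
use parquet::schema::types::SchemaDescriptor;

fn main() {
    let message_type = "
    message schema {
      REQUIRED GROUP a {
        REQUIRED GROUP b {
          REQUIRED GROUP c {
            REQUIRED INT32 d;
          }
        }
      }
      REQUIRED INT64 e;
    }
    ";
    let schema =
        SchemaDescriptor::new(Arc::new(parse_message_type(message_type).unwrap()));

    // Leaf 0 is `a.b.c.d`, whose root is field 0 (`a`); leaf 1 is `e`, root 1
    assert_eq!(schema.get_column_root_idx(0), 0);
    assert_eq!(schema.get_column_root_idx(1), 1);
    assert_eq!(schema.get_column_root(0).name(), "a");
}
```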
