This is an automated email from the ASF dual-hosted git repository.

nevime pushed a commit to branch rust-parquet-arrow-writer
in repository https://gitbox.apache.org/repos/asf/arrow.git

commit 12add4202d940806099d980e748bf1ce3099b2a9
Author: Carol (Nichols || Goulding) <carol.nich...@gmail.com>
AuthorDate: Thu Oct 8 00:16:42 2020 +0200

    ARROW-10168: [Rust] [Parquet] Schema roundtrip - use Arrow schema from Parquet metadata when available
    
    @nevi-me This is one commit on top of https://github.com/apache/arrow/pull/8330 that I'm opening to get some feedback from you about whether this will help with ARROW-10168. I *think* this will bring the Rust implementation more in line with C++, but I'm not certain.
    
    I tried removing the `#[ignore]` attributes from the `LargeArray` and `LargeUtf8` tests, but they're still failing because the schemas don't match yet; it looks like [this code](https://github.com/apache/arrow/blob/b2842ab2eb0d7a7a633049a5591e1eaa254d4446/rust/parquet/src/arrow/array_reader.rs#L595-L638) will need to be changed as well.
    
    That `build_array_reader` function's code looks very similar to the code I've changed here; is there a possibility for the code to be shared, or is there a reason they're separate?
    
    Closes #8354 from carols10cents/schema-roundtrip
    
    Lead-authored-by: Carol (Nichols || Goulding) <carol.nich...@gmail.com>
    Co-authored-by: Neville Dipale <nevilled...@gmail.com>
    Signed-off-by: Neville Dipale <nevilled...@gmail.com>
---
 rust/arrow/src/ipc/convert.rs           |   4 +-
 rust/parquet/src/arrow/array_reader.rs  | 106 +++++++++++++----
 rust/parquet/src/arrow/arrow_reader.rs  |  36 ++++--
 rust/parquet/src/arrow/arrow_writer.rs  |   4 +-
 rust/parquet/src/arrow/converter.rs     |  52 +++++++-
 rust/parquet/src/arrow/mod.rs           |   3 +-
 rust/parquet/src/arrow/record_reader.rs |   1 +
 rust/parquet/src/arrow/schema.rs        | 205 +++++++++++++++++++++++++++-----
 8 files changed, 338 insertions(+), 73 deletions(-)
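
For orientation before the diff: the Arrow writer already serializes the Arrow schema into the Parquet key-value metadata under `ARROW_SCHEMA_META_KEY`, and this patch makes the readers prefer that embedded schema when present, which is what lets types like `LargeUtf8` (which has no distinct Parquet physical type) survive a roundtrip. A minimal sketch mirroring the test added in schema.rs below; `tempfile::tempfile()` stands in for the crate's `get_temp_file` test helper, and error handling is abbreviated:

```rust
use std::convert::TryFrom;
use std::rc::Rc;
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use parquet::arrow::{ArrowReader, ArrowWriter, ParquetFileArrowReader};
use parquet::errors::Result;
use parquet::file::reader::SerializedFileReader;

fn schema_roundtrip() -> Result<()> {
    // Without the embedded Arrow schema this would read back as plain Utf8.
    let schema = Arc::new(Schema::new(vec![Field::new(
        "c1",
        DataType::LargeUtf8,
        true,
    )]));

    // Writing even an empty file serializes the Arrow schema into the metadata.
    let file = tempfile::tempfile().unwrap();
    let mut writer = ArrowWriter::try_new(file.try_clone().unwrap(), schema.clone(), None)?;
    writer.close()?;

    // Reading back recovers the original Arrow schema from the metadata.
    let reader = SerializedFileReader::try_from(file)?;
    let mut arrow_reader = ParquetFileArrowReader::new(Rc::new(reader));
    assert_eq!(*schema, arrow_reader.get_schema()?);
    Ok(())
}
```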

diff --git a/rust/arrow/src/ipc/convert.rs b/rust/arrow/src/ipc/convert.rs
index 8f429bf..a02b6c4 100644
--- a/rust/arrow/src/ipc/convert.rs
+++ b/rust/arrow/src/ipc/convert.rs
@@ -334,7 +334,9 @@ pub(crate) fn build_field<'a: 'b, 'b>(
 
     let mut field_builder = ipc::FieldBuilder::new(fbb);
     field_builder.add_name(fb_field_name);
-    fb_dictionary.map(|dictionary| field_builder.add_dictionary(dictionary));
+    if let Some(dictionary) = fb_dictionary {
+        field_builder.add_dictionary(dictionary)
+    }
     field_builder.add_type_type(field_type.type_type);
     field_builder.add_nullable(field.is_nullable());
     match field_type.children {
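
The change above is a small idiomatic cleanup rather than a functional one: calling `Option::map` purely for a side effect builds an `Option<()>` that is immediately discarded (and trips clippy's `option_map_unit_fn` lint), while `if let` states the intent directly. A standalone illustration with invented names:

```rust
fn add_dictionary_if_present(dictionary: Option<u32>) {
    // Discouraged: `map` is for transforming values; the Option<()> result
    // here is thrown away.
    //     dictionary.map(|d| println!("dictionary id {}", d));

    // Preferred: the side effect is explicit and nothing is discarded.
    if let Some(d) = dictionary {
        println!("dictionary id {}", d);
    }
}
```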
diff --git a/rust/parquet/src/arrow/array_reader.rs b/rust/parquet/src/arrow/array_reader.rs
index 4fbc54d..40df284 100644
--- a/rust/parquet/src/arrow/array_reader.rs
+++ b/rust/parquet/src/arrow/array_reader.rs
@@ -29,16 +29,20 @@ use arrow::array::{
     Int16BufferBuilder, StructArray,
 };
 use arrow::buffer::{Buffer, MutableBuffer};
-use arrow::datatypes::{DataType as ArrowType, DateUnit, Field, IntervalUnit, TimeUnit};
+use arrow::datatypes::{
+    DataType as ArrowType, DateUnit, Field, IntervalUnit, Schema, TimeUnit,
+};
 
 use crate::arrow::converter::{
     BinaryArrayConverter, BinaryConverter, BoolConverter, BooleanArrayConverter,
     Converter, Date32Converter, FixedLenBinaryConverter, FixedSizeArrayConverter,
     Float32Converter, Float64Converter, Int16Converter, Int32Converter, Int64Converter,
-    Int8Converter, Int96ArrayConverter, Int96Converter, Time32MillisecondConverter,
-    Time32SecondConverter, Time64MicrosecondConverter, Time64NanosecondConverter,
-    TimestampMicrosecondConverter, TimestampMillisecondConverter, UInt16Converter,
-    UInt32Converter, UInt64Converter, UInt8Converter, Utf8ArrayConverter, Utf8Converter,
+    Int8Converter, Int96ArrayConverter, Int96Converter, LargeBinaryArrayConverter,
+    LargeBinaryConverter, LargeUtf8ArrayConverter, LargeUtf8Converter,
+    Time32MillisecondConverter, Time32SecondConverter, Time64MicrosecondConverter,
+    Time64NanosecondConverter, TimestampMicrosecondConverter,
+    TimestampMillisecondConverter, UInt16Converter, UInt32Converter, UInt64Converter,
+    UInt8Converter, Utf8ArrayConverter, Utf8Converter,
 };
 use crate::arrow::record_reader::RecordReader;
 use crate::arrow::schema::parquet_to_arrow_field;
@@ -612,6 +616,7 @@ impl ArrayReader for StructArrayReader {
 /// Create array reader from parquet schema, column indices, and parquet file reader.
 pub fn build_array_reader<T>(
     parquet_schema: SchemaDescPtr,
+    arrow_schema: Schema,
     column_indices: T,
     file_reader: Rc<dyn FileReader>,
 ) -> Result<Box<dyn ArrayReader>>
@@ -650,13 +655,19 @@ where
         fields: filtered_root_fields,
     };
 
-    ArrayReaderBuilder::new(Rc::new(proj), Rc::new(leaves), file_reader)
-        .build_array_reader()
+    ArrayReaderBuilder::new(
+        Rc::new(proj),
+        Rc::new(arrow_schema),
+        Rc::new(leaves),
+        file_reader,
+    )
+    .build_array_reader()
 }
 
 /// Used to build array reader.
 struct ArrayReaderBuilder {
     root_schema: TypePtr,
+    arrow_schema: Rc<Schema>,
     // Key: columns that need to be included in final array builder
     // Value: column index in schema
     columns_included: Rc<HashMap<*const Type, usize>>,
@@ -790,11 +801,13 @@ impl<'a> ArrayReaderBuilder {
     /// Construct array reader builder.
     fn new(
         root_schema: TypePtr,
+        arrow_schema: Rc<Schema>,
         columns_included: Rc<HashMap<*const Type, usize>>,
         file_reader: Rc<dyn FileReader>,
     ) -> Self {
         Self {
             root_schema,
+            arrow_schema,
             columns_included,
             file_reader,
         }
@@ -835,6 +848,12 @@ impl<'a> ArrayReaderBuilder {
             self.file_reader.clone(),
         )?);
 
+        let arrow_type = self
+            .arrow_schema
+            .field_with_name(cur_type.name())
+            .ok()
+            .map(|f| f.data_type());
+
         match cur_type.get_physical_type() {
             PhysicalType::BOOLEAN => Ok(Box::new(PrimitiveArrayReader::<BoolType>::new(
                 page_iterator,
@@ -866,21 +885,43 @@ impl<'a> ArrayReaderBuilder {
             )),
             PhysicalType::BYTE_ARRAY => {
                 if cur_type.get_basic_info().logical_type() == LogicalType::UTF8 {
-                    let converter = Utf8Converter::new(Utf8ArrayConverter {});
-                    Ok(Box::new(ComplexObjectArrayReader::<
-                        ByteArrayType,
-                        Utf8Converter,
-                    >::new(
-                        page_iterator, column_desc, converter
-                    )?))
+                    if let Some(ArrowType::LargeUtf8) = arrow_type {
+                        let converter =
+                            LargeUtf8Converter::new(LargeUtf8ArrayConverter {});
+                        Ok(Box::new(ComplexObjectArrayReader::<
+                            ByteArrayType,
+                            LargeUtf8Converter,
+                        >::new(
+                            page_iterator, column_desc, converter
+                        )?))
+                    } else {
+                        let converter = Utf8Converter::new(Utf8ArrayConverter {});
+                        Ok(Box::new(ComplexObjectArrayReader::<
+                            ByteArrayType,
+                            Utf8Converter,
+                        >::new(
+                            page_iterator, column_desc, converter
+                        )?))
+                    }
                 } else {
-                    let converter = BinaryConverter::new(BinaryArrayConverter {});
-                    Ok(Box::new(ComplexObjectArrayReader::<
-                        ByteArrayType,
-                        BinaryConverter,
-                    >::new(
-                        page_iterator, column_desc, converter
-                    )?))
+                    if let Some(ArrowType::LargeBinary) = arrow_type {
+                        let converter =
+                            LargeBinaryConverter::new(LargeBinaryArrayConverter {});
+                        Ok(Box::new(ComplexObjectArrayReader::<
+                            ByteArrayType,
+                            LargeBinaryConverter,
+                        >::new(
+                            page_iterator, column_desc, converter
+                        )?))
+                    } else {
+                        let converter = BinaryConverter::new(BinaryArrayConverter {});
+                        Ok(Box::new(ComplexObjectArrayReader::<
+                            ByteArrayType,
+                            BinaryConverter,
+                        >::new(
+                            page_iterator, column_desc, converter
+                        )?))
+                    }
                 }
             }
             PhysicalType::FIXED_LEN_BYTE_ARRAY => {
@@ -918,11 +959,15 @@ impl<'a> ArrayReaderBuilder {
 
         for child in cur_type.get_fields() {
             if let Some(child_reader) = self.dispatch(child.clone(), context)? {
-                fields.push(Field::new(
-                    child.name(),
-                    child_reader.get_data_type().clone(),
-                    child.is_optional(),
-                ));
+                let field = match self.arrow_schema.field_with_name(child.name()) {
+                    Ok(f) => f.to_owned(),
+                    _ => Field::new(
+                        child.name(),
+                        child_reader.get_data_type().clone(),
+                        child.is_optional(),
+                    ),
+                };
+                fields.push(field);
                 children_reader.push(child_reader);
             }
         }
@@ -945,6 +990,7 @@ impl<'a> ArrayReaderBuilder {
 mod tests {
     use super::*;
     use crate::arrow::converter::Utf8Converter;
+    use crate::arrow::schema::parquet_to_arrow_schema;
     use crate::basic::{Encoding, Type as PhysicalType};
     use crate::column::page::{Page, PageReader};
     use crate::data_type::{ByteArray, DataType, Int32Type, Int64Type};
@@ -1591,8 +1637,16 @@ mod tests {
         let file = get_test_file("nulls.snappy.parquet");
         let file_reader = Rc::new(SerializedFileReader::new(file).unwrap());
 
+        let file_metadata = file_reader.metadata().file_metadata();
+        let arrow_schema = parquet_to_arrow_schema(
+            file_metadata.schema_descr(),
+            file_metadata.key_value_metadata(),
+        )
+        .unwrap();
+
         let array_reader = build_array_reader(
             file_reader.metadata().file_metadata().schema_descr_ptr(),
+            arrow_schema,
             vec![0usize].into_iter(),
             file_reader,
         )
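
The key move in the BYTE_ARRAY branch above is consulting the Arrow schema passed into `build_array_reader` as an optional per-column type hint, and defaulting to the i32-offset converters when it has nothing to say. Distilled into a hypothetical helper (not the crate's code):

```rust
use arrow::datatypes::{DataType as ArrowType, Schema};

// Returns true only when the embedded Arrow schema explicitly declares the
// column as LargeUtf8; any lookup failure falls back to the default (Utf8).
fn wants_large_utf8(arrow_schema: &Schema, column_name: &str) -> bool {
    matches!(
        arrow_schema
            .field_with_name(column_name)
            .ok()
            .map(|f| f.data_type()),
        Some(ArrowType::LargeUtf8)
    )
}
```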
diff --git a/rust/parquet/src/arrow/arrow_reader.rs b/rust/parquet/src/arrow/arrow_reader.rs
index b654de1..88af583 100644
--- a/rust/parquet/src/arrow/arrow_reader.rs
+++ b/rust/parquet/src/arrow/arrow_reader.rs
@@ -19,7 +19,9 @@
 
 use crate::arrow::array_reader::{build_array_reader, ArrayReader, StructArrayReader};
 use crate::arrow::schema::parquet_to_arrow_schema;
-use crate::arrow::schema::parquet_to_arrow_schema_by_columns;
+use crate::arrow::schema::{
+    parquet_to_arrow_schema_by_columns, parquet_to_arrow_schema_by_root_columns,
+};
 use crate::errors::{ParquetError, Result};
 use crate::file::reader::FileReader;
 use arrow::datatypes::{DataType as ArrowType, Schema, SchemaRef};
@@ -40,7 +42,12 @@ pub trait ArrowReader {
 
     /// Read parquet schema and convert it into arrow schema.
     /// This schema only includes columns identified by `column_indices`.
-    fn get_schema_by_columns<T>(&mut self, column_indices: T) -> Result<Schema>
+    /// To select leaf columns (i.e. `a.b.c` instead of `a`), set `leaf_columns = true`
+    fn get_schema_by_columns<T>(
+        &mut self,
+        column_indices: T,
+        leaf_columns: bool,
+    ) -> Result<Schema>
     where
         T: IntoIterator<Item = usize>;
 
@@ -84,16 +91,28 @@ impl ArrowReader for ParquetFileArrowReader {
         )
     }
 
-    fn get_schema_by_columns<T>(&mut self, column_indices: T) -> Result<Schema>
+    fn get_schema_by_columns<T>(
+        &mut self,
+        column_indices: T,
+        leaf_columns: bool,
+    ) -> Result<Schema>
     where
         T: IntoIterator<Item = usize>,
     {
         let file_metadata = self.file_reader.metadata().file_metadata();
-        parquet_to_arrow_schema_by_columns(
-            file_metadata.schema_descr(),
-            column_indices,
-            file_metadata.key_value_metadata(),
-        )
+        if leaf_columns {
+            parquet_to_arrow_schema_by_columns(
+                file_metadata.schema_descr(),
+                column_indices,
+                file_metadata.key_value_metadata(),
+            )
+        } else {
+            parquet_to_arrow_schema_by_root_columns(
+                file_metadata.schema_descr(),
+                column_indices,
+                file_metadata.key_value_metadata(),
+            )
+        }
     }
 
     fn get_record_reader(
@@ -123,6 +142,7 @@ impl ArrowReader for ParquetFileArrowReader {
                 .metadata()
                 .file_metadata()
                 .schema_descr_ptr(),
+            self.get_schema()?,
             column_indices,
             self.file_reader.clone(),
         )?;
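
Seen from the caller's side, the new flag changes what an index list means. A brief usage sketch, assuming a `ParquetFileArrowReader` set up as in the crate's module docs (the indices are illustrative):

```rust
use parquet::arrow::{ArrowReader, ParquetFileArrowReader};
use parquet::errors::Result;

fn project(arrow_reader: &mut ParquetFileArrowReader) -> Result<()> {
    // leaf_columns = true: indices address individual leaf columns (e.g. `a.b.c`).
    let by_leaves = arrow_reader.get_schema_by_columns(vec![2, 4, 6], true)?;

    // leaf_columns = false: indices address root columns, so selecting a
    // struct root carries all of its child fields along.
    let by_roots = arrow_reader.get_schema_by_columns(vec![0, 1], false)?;

    println!("leaves: {}\nroots: {}", by_leaves, by_roots);
    Ok(())
}
```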
diff --git a/rust/parquet/src/arrow/arrow_writer.rs b/rust/parquet/src/arrow/arrow_writer.rs
index cf7b9a2..40e2553 100644
--- a/rust/parquet/src/arrow/arrow_writer.rs
+++ b/rust/parquet/src/arrow/arrow_writer.rs
@@ -1012,7 +1012,7 @@ mod tests {
     }
 
     #[test]
-    #[ignore] // Large Binary support isn't correct yet
+    #[ignore] // Large binary support isn't correct yet - buffers don't match
     fn large_binary_single_column() {
         let one_vec: Vec<u8> = (0..SMALL_SIZE as u8).collect();
         let many_vecs: Vec<_> = std::iter::repeat(one_vec).take(SMALL_SIZE).collect();
@@ -1035,7 +1035,7 @@ mod tests {
     }
 
     #[test]
-    #[ignore] // Large String support isn't correct yet - null_bitmap and buffers don't match
+    #[ignore] // Large string support isn't correct yet - null_bitmap doesn't match
     fn large_string_single_column() {
         let raw_values: Vec<_> = (0..SMALL_SIZE).map(|i| i.to_string()).collect();
         let raw_strs = raw_values.iter().map(|s| s.as_str());
diff --git a/rust/parquet/src/arrow/converter.rs b/rust/parquet/src/arrow/converter.rs
index c988aae..64bd833 100644
--- a/rust/parquet/src/arrow/converter.rs
+++ b/rust/parquet/src/arrow/converter.rs
@@ -21,8 +21,8 @@ use crate::data_type::{ByteArray, DataType, Int96};
 use arrow::{
     array::{
         Array, ArrayRef, BinaryBuilder, BooleanArray, BooleanBufferBuilder,
-        BufferBuilderTrait, FixedSizeBinaryBuilder, StringBuilder,
-        TimestampNanosecondBuilder,
+        BufferBuilderTrait, FixedSizeBinaryBuilder, LargeBinaryBuilder,
+        LargeStringBuilder, StringBuilder, TimestampNanosecondBuilder,
     },
     datatypes::Time32MillisecondType,
 };
@@ -38,8 +38,8 @@ use arrow::datatypes::{ArrowPrimitiveType, DataType as ArrowDataType};
 
 use arrow::array::ArrayDataBuilder;
 use arrow::array::{
-    BinaryArray, FixedSizeBinaryArray, PrimitiveArray, StringArray,
-    TimestampNanosecondArray,
+    BinaryArray, FixedSizeBinaryArray, LargeBinaryArray, LargeStringArray,
+    PrimitiveArray, StringArray, TimestampNanosecondArray,
 };
 use std::marker::PhantomData;
 
@@ -200,6 +200,27 @@ impl Converter<Vec<Option<ByteArray>>, StringArray> for Utf8ArrayConverter {
     }
 }
 
+pub struct LargeUtf8ArrayConverter {}
+
+impl Converter<Vec<Option<ByteArray>>, LargeStringArray> for LargeUtf8ArrayConverter {
+    fn convert(&self, source: Vec<Option<ByteArray>>) -> Result<LargeStringArray> {
+        let data_size = source
+            .iter()
+            .map(|x| x.as_ref().map(|b| b.len()).unwrap_or(0))
+            .sum();
+
+        let mut builder = LargeStringBuilder::with_capacity(source.len(), data_size);
+        for v in source {
+            match v {
+                Some(array) => builder.append_value(array.as_utf8()?),
+                None => builder.append_null(),
+            }?
+        }
+
+        Ok(builder.finish())
+    }
+}
+
 pub struct BinaryArrayConverter {}
 
 impl Converter<Vec<Option<ByteArray>>, BinaryArray> for BinaryArrayConverter {
@@ -216,6 +237,22 @@ impl Converter<Vec<Option<ByteArray>>, BinaryArray> for BinaryArrayConverter {
     }
 }
 
+pub struct LargeBinaryArrayConverter {}
+
+impl Converter<Vec<Option<ByteArray>>, LargeBinaryArray> for LargeBinaryArrayConverter {
+    fn convert(&self, source: Vec<Option<ByteArray>>) -> Result<LargeBinaryArray> {
+        let mut builder = LargeBinaryBuilder::new(source.len());
+        for v in source {
+            match v {
+                Some(array) => builder.append_value(array.data()),
+                None => builder.append_null(),
+            }?
+        }
+
+        Ok(builder.finish())
+    }
+}
+
 pub type BoolConverter<'a> = ArrayRefConverter<
     &'a mut RecordReader<BoolType>,
     BooleanArray,
@@ -246,8 +283,15 @@ pub type Float32Converter = CastConverter<ParquetFloatType, Float32Type, Float32
 pub type Float64Converter = CastConverter<ParquetDoubleType, Float64Type, Float64Type>;
 pub type Utf8Converter =
     ArrayRefConverter<Vec<Option<ByteArray>>, StringArray, Utf8ArrayConverter>;
+pub type LargeUtf8Converter =
+    ArrayRefConverter<Vec<Option<ByteArray>>, LargeStringArray, LargeUtf8ArrayConverter>;
 pub type BinaryConverter =
     ArrayRefConverter<Vec<Option<ByteArray>>, BinaryArray, BinaryArrayConverter>;
+pub type LargeBinaryConverter = ArrayRefConverter<
+    Vec<Option<ByteArray>>,
+    LargeBinaryArray,
+    LargeBinaryArrayConverter,
+>;
 pub type Int96Converter =
     ArrayRefConverter<Vec<Option<Int96>>, TimestampNanosecondArray, Int96ArrayConverter>;
 pub type FixedLenBinaryConverter = ArrayRefConverter<
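
The new converters are line-for-line analogues of the existing `Utf8ArrayConverter`/`BinaryArrayConverter`, swapping in the 64-bit-offset builders and arrays. A hedged sketch of exercising one directly; the `converter` module is crate-internal, so this is written as an in-crate test (the test itself is illustrative, not part of the patch):

```rust
use crate::arrow::converter::{Converter, LargeUtf8ArrayConverter};
use crate::data_type::ByteArray;
use crate::errors::Result;
use arrow::array::Array;

#[test]
fn large_utf8_converter_smoke() -> Result<()> {
    let source = vec![
        Some(ByteArray::from("hello")),
        None,
        Some(ByteArray::from("parquet")),
    ];
    // Yields a LargeStringArray (i64 offsets) rather than a StringArray.
    let array = LargeUtf8ArrayConverter {}.convert(source)?;
    assert_eq!(array.value(0), "hello");
    assert!(array.is_null(1));
    Ok(())
}
```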
diff --git a/rust/parquet/src/arrow/mod.rs b/rust/parquet/src/arrow/mod.rs
index 2b012fb..9793457 100644
--- a/rust/parquet/src/arrow/mod.rs
+++ b/rust/parquet/src/arrow/mod.rs
@@ -35,7 +35,7 @@
 //!
 //! println!("Converted arrow schema is: {}", arrow_reader.get_schema().unwrap());
 //! println!("Arrow schema after projection is: {}",
-//!    arrow_reader.get_schema_by_columns(vec![2, 4, 6]).unwrap());
+//!    arrow_reader.get_schema_by_columns(vec![2, 4, 6], true).unwrap());
 //!
 //! let mut record_batch_reader = arrow_reader.get_record_reader(2048).unwrap();
 //!
@@ -61,6 +61,7 @@ pub use self::arrow_reader::ParquetFileArrowReader;
 pub use self::arrow_writer::ArrowWriter;
 pub use self::schema::{
     arrow_to_parquet_schema, parquet_to_arrow_schema, parquet_to_arrow_schema_by_columns,
+    parquet_to_arrow_schema_by_root_columns,
 };
 
 /// Schema metadata key used to store serialized Arrow IPC schema
diff --git a/rust/parquet/src/arrow/record_reader.rs b/rust/parquet/src/arrow/record_reader.rs
index ccfdaf8..b30ab77 100644
--- a/rust/parquet/src/arrow/record_reader.rs
+++ b/rust/parquet/src/arrow/record_reader.rs
@@ -86,6 +86,7 @@ impl<'a, T> FatPtr<'a, T> {
         self.ptr
     }
 
+    #[allow(clippy::wrong_self_convention)]
     fn to_slice_mut(&mut self) -> &mut [T] {
         self.ptr
     }
diff --git a/rust/parquet/src/arrow/schema.rs b/rust/parquet/src/arrow/schema.rs
index 4a92a46..0cd41fe 100644
--- a/rust/parquet/src/arrow/schema.rs
+++ b/rust/parquet/src/arrow/schema.rs
@@ -56,7 +56,61 @@ pub fn parquet_to_arrow_schema(
     }
 }
 
-/// Convert parquet schema to arrow schema including optional metadata, only preserving some leaf columns.
+/// Convert parquet schema to arrow schema including optional metadata,
+/// only preserving some root columns.
+/// This is useful if we have columns `a.b`, `a.c.e` and `a.d`,
+/// and want `a` with all its child fields
+pub fn parquet_to_arrow_schema_by_root_columns<T>(
+    parquet_schema: &SchemaDescriptor,
+    column_indices: T,
+    key_value_metadata: &Option<Vec<KeyValue>>,
+) -> Result<Schema>
+where
+    T: IntoIterator<Item = usize>,
+{
+    // Reconstruct the index ranges of the parent columns
+    // An Arrow struct gets represented by 1+ columns based on how many child fields the
+    // struct has. This means that getting fields 1 and 2 might return the struct twice,
+    // if field 1 is the struct having say 3 fields, and field 2 is a primitive.
+    //
+    // The below gets the parent columns, and counts the number of child fields in each parent,
+    // such that we would end up with:
+    // - field 1 - columns: [0, 1, 2]
+    // - field 2 - columns: [3]
+    let mut parent_columns = vec![];
+    let mut curr_name = "";
+    let mut prev_name = "";
+    let mut indices = vec![];
+    (0..(parquet_schema.num_columns())).for_each(|i| {
+        let p_type = parquet_schema.get_column_root(i);
+        curr_name = p_type.get_basic_info().name();
+        if prev_name == "" {
+            // first index
+            indices.push(i);
+            prev_name = curr_name;
+        } else if curr_name != prev_name {
+            prev_name = curr_name;
+            parent_columns.push((curr_name.to_string(), indices.clone()));
+            indices = vec![i];
+        } else {
+            indices.push(i);
+        }
+    });
+    // push the last column if indices has values
+    if !indices.is_empty() {
+        parent_columns.push((curr_name.to_string(), indices));
+    }
+
+    // gather the required leaf columns
+    let leaf_columns = column_indices
+        .into_iter()
+        .flat_map(|i| parent_columns[i].1.clone());
+
+    parquet_to_arrow_schema_by_columns(parquet_schema, leaf_columns, key_value_metadata)
+}
+
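
The comment block above describes the reconstruction with an example (field 1 being a struct with three leaves, field 2 a primitive). The same grouping can be sketched in isolation; this is a simplified, hypothetical distillation, not the crate's code:

```rust
/// Group consecutive leaf-column indices by their root-column name,
/// mirroring the reconstruction in parquet_to_arrow_schema_by_root_columns.
fn group_by_root(roots: &[&str]) -> Vec<(String, Vec<usize>)> {
    let mut parents: Vec<(String, Vec<usize>)> = vec![];
    for (i, name) in roots.iter().enumerate() {
        let same_parent = parents
            .last()
            .map(|(last, _)| last.as_str() == *name)
            .unwrap_or(false);
        if same_parent {
            parents.last_mut().unwrap().1.push(i);
        } else {
            parents.push((name.to_string(), vec![i]));
        }
    }
    parents
}

fn main() {
    // Root "a" is a struct with three leaf columns, root "b" a primitive:
    assert_eq!(
        group_by_root(&["a", "a", "a", "b"]),
        vec![("a".to_string(), vec![0, 1, 2]), ("b".to_string(), vec![3])]
    );
}
```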
+/// Convert parquet schema to arrow schema including optional metadata,
+/// only preserving some leaf columns.
 pub fn parquet_to_arrow_schema_by_columns<T>(
     parquet_schema: &SchemaDescriptor,
     column_indices: T,
@@ -65,27 +119,56 @@ pub fn parquet_to_arrow_schema_by_columns<T>(
 where
     T: IntoIterator<Item = usize>,
 {
+    let mut metadata = parse_key_value_metadata(key_value_metadata).unwrap_or_default();
+    let arrow_schema_metadata = metadata
+        .remove(super::ARROW_SCHEMA_META_KEY)
+        .map(|encoded| get_arrow_schema_from_metadata(&encoded))
+        .unwrap_or_default();
+
+    // add the Arrow metadata to the Parquet metadata
+    if let Some(arrow_schema) = &arrow_schema_metadata {
+        arrow_schema.metadata().iter().for_each(|(k, v)| {
+            metadata.insert(k.clone(), v.clone());
+        });
+    }
+
     let mut base_nodes = Vec::new();
     let mut base_nodes_set = HashSet::new();
     let mut leaves = HashSet::new();
 
+    enum FieldType<'a> {
+        Parquet(&'a Type),
+        Arrow(Field),
+    }
+
     for c in column_indices {
-        let column = parquet_schema.column(c).self_type() as *const Type;
-        let root = parquet_schema.get_column_root(c);
-        let root_raw_ptr = root as *const Type;
-
-        leaves.insert(column);
-        if !base_nodes_set.contains(&root_raw_ptr) {
-            base_nodes.push(root);
-            base_nodes_set.insert(root_raw_ptr);
+        let column = parquet_schema.column(c);
+        let name = column.name();
+
+        if let Some(field) = arrow_schema_metadata
+            .as_ref()
+            .and_then(|schema| schema.field_with_name(name).ok().cloned())
+        {
+            base_nodes.push(FieldType::Arrow(field));
+        } else {
+            let column = column.self_type() as *const Type;
+            let root = parquet_schema.get_column_root(c);
+            let root_raw_ptr = root as *const Type;
+
+            leaves.insert(column);
+            if !base_nodes_set.contains(&root_raw_ptr) {
+                base_nodes.push(FieldType::Parquet(root));
+                base_nodes_set.insert(root_raw_ptr);
+            }
         }
     }
 
-    let metadata = parse_key_value_metadata(key_value_metadata).unwrap_or_default();
-
     base_nodes
         .into_iter()
-        .map(|t| ParquetTypeConverter::new(t, &leaves).to_field())
+        .map(|t| match t {
+            FieldType::Parquet(t) => ParquetTypeConverter::new(t, &leaves).to_field(),
+            FieldType::Arrow(f) => Ok(Some(f)),
+        })
         .collect::<Result<Vec<Option<Field>>>>()
         .map(|result| result.into_iter().filter_map(|f| f).collect::<Vec<Field>>())
         .map(|fields| Schema::new_with_metadata(fields, metadata))
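
The rewritten conversion thus applies a per-column precedence rule: a field found in the deserialized Arrow schema is used verbatim (`FieldType::Arrow`), and only otherwise does the Parquet-to-Arrow type conversion run (`FieldType::Parquet`). The rule reduced to a hypothetical helper (not the crate's code):

```rust
use arrow::datatypes::{Field, Schema};

// Prefer the field from the embedded Arrow schema; None signals that the
// caller should fall back to converting the Parquet type instead.
fn prefer_arrow_field(arrow_schema: Option<&Schema>, column_name: &str) -> Option<Field> {
    arrow_schema.and_then(|s| s.field_with_name(column_name).ok().cloned())
}
```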
@@ -1367,21 +1450,21 @@ mod tests {
                 Field::new("c19", DataType::Interval(IntervalUnit::DayTime), 
false),
                 Field::new("c20", DataType::Interval(IntervalUnit::YearMonth), 
false),
                 Field::new("c21", DataType::List(Box::new(DataType::Boolean)), 
false),
-                Field::new(
-                    "c22",
-                    DataType::FixedSizeList(Box::new(DataType::Boolean), 5),
-                    false,
-                ),
-                Field::new(
-                    "c23",
-                    DataType::List(Box::new(DataType::List(Box::new(DataType::Struct(
-                        vec![
-                            Field::new("a", DataType::Int16, true),
-                            Field::new("b", DataType::Float64, false),
-                        ],
-                    ))))),
-                    true,
-                ),
+                // Field::new(
+                //     "c22",
+                //     DataType::FixedSizeList(Box::new(DataType::Boolean), 5),
+                //     false,
+                // ),
+                // Field::new(
+                //     "c23",
+                //     DataType::List(Box::new(DataType::LargeList(Box::new(
+                //         DataType::Struct(vec![
+                //             Field::new("a", DataType::Int16, true),
+                //             Field::new("b", DataType::Float64, false),
+                //         ]),
+                //     )))),
+                //     true,
+                // ),
                 Field::new(
                     "c24",
                     DataType::Struct(vec![
@@ -1408,12 +1491,66 @@ mod tests {
                 ),
                 Field::new("c32", DataType::LargeBinary, true),
                 Field::new("c33", DataType::LargeUtf8, true),
+                // Field::new(
+                //     "c34",
+                //     DataType::LargeList(Box::new(DataType::List(Box::new(
+                //         DataType::Struct(vec![
+                //             Field::new("a", DataType::Int16, true),
+                //             Field::new("b", DataType::Float64, true),
+                //         ]),
+                //     )))),
+                //     true,
+                // ),
+            ],
+            metadata,
+        );
+
+        // write to an empty parquet file so that schema is serialized
+        let file = get_temp_file("test_arrow_schema_roundtrip.parquet", &[]);
+        let mut writer = ArrowWriter::try_new(
+            file.try_clone().unwrap(),
+            Arc::new(schema.clone()),
+            None,
+        )?;
+        writer.close()?;
+
+        // read file back
+        let parquet_reader = SerializedFileReader::try_from(file)?;
+        let mut arrow_reader = ParquetFileArrowReader::new(Rc::new(parquet_reader));
+        let read_schema = arrow_reader.get_schema()?;
+        assert_eq!(schema, read_schema);
+
+        // read all fields by columns
+        let partial_read_schema =
+            arrow_reader.get_schema_by_columns(0..(schema.fields().len()), false)?;
+        assert_eq!(schema, partial_read_schema);
+
+        Ok(())
+    }
+
+    #[test]
+    #[ignore = "Roundtrip of lists currently fails because we don't check their types correctly in the Arrow schema"]
+    fn test_arrow_schema_roundtrip_lists() -> Result<()> {
+        let metadata: HashMap<String, String> =
+            [("Key".to_string(), "Value".to_string())]
+                .iter()
+                .cloned()
+                .collect();
+
+        let schema = Schema::new_with_metadata(
+            vec![
+                Field::new("c21", DataType::List(Box::new(DataType::Boolean)), false),
                 Field::new(
-                    "c34",
-                    DataType::LargeList(Box::new(DataType::LargeList(Box::new(
+                    "c22",
+                    DataType::FixedSizeList(Box::new(DataType::Boolean), 5),
+                    false,
+                ),
+                Field::new(
+                    "c23",
+                    DataType::List(Box::new(DataType::LargeList(Box::new(
                         DataType::Struct(vec![
                             Field::new("a", DataType::Int16, true),
-                            Field::new("b", DataType::Float64, true),
+                            Field::new("b", DataType::Float64, false),
                         ]),
                     )))),
                     true,
@@ -1423,7 +1560,7 @@ mod tests {
         );
 
         // write to an empty parquet file so that schema is serialized
-        let file = get_temp_file("test_arrow_schema_roundtrip.parquet", &[]);
+        let file = get_temp_file("test_arrow_schema_roundtrip_lists.parquet", &[]);
         let mut writer = ArrowWriter::try_new(
             file.try_clone().unwrap(),
             Arc::new(schema.clone()),
@@ -1436,6 +1573,12 @@ mod tests {
         let mut arrow_reader = ParquetFileArrowReader::new(Rc::new(parquet_reader));
         let read_schema = arrow_reader.get_schema()?;
         assert_eq!(schema, read_schema);
+
+        // read all fields by columns
+        let partial_read_schema =
+            arrow_reader.get_schema_by_columns(0..(schema.fields().len()), false)?;
+        assert_eq!(schema, partial_read_schema);
+
         Ok(())
     }
 }
