This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new d9aaa437ca Add `RecordReader` trait and proc macro to implement it for 
a struct (#4773)
d9aaa437ca is described below

commit d9aaa437ca4ebf5a3500c865272243612862c7d4
Author: Joseph Rance <[email protected]>
AuthorDate: Mon Oct 30 11:40:34 2023 +0000

    Add `RecordReader` trait and proc macro to implement it for a struct (#4773)
    
    * add and implement RecordReader trait for rust structs
    
    * Fix typo in comment
    
    * run cargo fmt
    
    * partially solve issues raised in review
    
    * remove references
    
    * change interface to use vectors
    
    * change interface to use vectors in  as well
    
    * update comments
    
    * remove initialisation requirement
    
    * prevent conflicts with existing default implementation
    
    * update documentation
    
    * run cargo fmt
    
    * change writer back to slice
    
    * change 'Handle' back to 'Derive' for RecordWriter macro in readme
    
    ---------
    
    Co-authored-by: joseph rance <[email protected]>
---
 parquet/src/record/mod.rs                          |   2 +
 .../record/{record_writer.rs => record_reader.rs}  |  19 +-
 parquet/src/record/record_writer.rs                |   4 +
 parquet_derive/README.md                           |  51 +++-
 parquet_derive/src/lib.rs                          |  88 +++++-
 parquet_derive/src/parquet_field.rs                | 338 +++++++++++++++++++--
 parquet_derive_test/src/lib.rs                     |  70 ++++-
 7 files changed, 532 insertions(+), 40 deletions(-)

diff --git a/parquet/src/record/mod.rs b/parquet/src/record/mod.rs
index 771d8058c9..f40e91418d 100644
--- a/parquet/src/record/mod.rs
+++ b/parquet/src/record/mod.rs
@@ -19,6 +19,7 @@
 
 mod api;
 pub mod reader;
+mod record_reader;
 mod record_writer;
 mod triplet;
 
@@ -26,5 +27,6 @@ pub use self::{
     api::{
         Field, List, ListAccessor, Map, MapAccessor, Row, RowAccessor, 
RowColumnIter, RowFormatter,
     },
+    record_reader::RecordReader,
     record_writer::RecordWriter,
 };
diff --git a/parquet/src/record/record_writer.rs 
b/parquet/src/record/record_reader.rs
similarity index 69%
copy from parquet/src/record/record_writer.rs
copy to parquet/src/record/record_reader.rs
index 62099051f5..bcfeb95dcd 100644
--- a/parquet/src/record/record_writer.rs
+++ b/parquet/src/record/record_reader.rs
@@ -15,17 +15,16 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::schema::types::TypePtr;
-
 use super::super::errors::ParquetError;
-use super::super::file::writer::SerializedRowGroupWriter;
+use super::super::file::reader::RowGroupReader;
 
-pub trait RecordWriter<T> {
-    fn write_to_row_group<W: std::io::Write + Send>(
-        &self,
-        row_group_writer: &mut SerializedRowGroupWriter<W>,
+/// Read `num_records` records from `row_group_reader` into `self`.
+///
+/// The type parameter `T` is used to work around the rust orphan rule
+/// when implementing on types such as `Vec<T>`.
+pub trait RecordReader<T> {
+    /// Reads `num_records` records from `row_group_reader` into `self`,
+    /// returning a `ParquetError` if reading fails.
+    fn read_from_row_group(
+        &mut self,
+        row_group_reader: &mut dyn RowGroupReader,
+        num_records: usize,
     ) -> Result<(), ParquetError>;
-
-    /// Generated schema
-    fn schema(&self) -> Result<TypePtr, ParquetError>;
 }
diff --git a/parquet/src/record/record_writer.rs 
b/parquet/src/record/record_writer.rs
index 62099051f5..0b2b95ef7d 100644
--- a/parquet/src/record/record_writer.rs
+++ b/parquet/src/record/record_writer.rs
@@ -20,6 +20,10 @@ use crate::schema::types::TypePtr;
 use super::super::errors::ParquetError;
 use super::super::file::writer::SerializedRowGroupWriter;
 
+/// `write_to_row_group` writes from `self` into `row_group_writer`
+/// `schema` builds the schema used by `row_group_writer`
+/// The type parameter `T` is used to work around the rust orphan rule
+/// when implementing on types such as `&[T]`.
 pub trait RecordWriter<T> {
     fn write_to_row_group<W: std::io::Write + Send>(
         &self,
diff --git a/parquet_derive/README.md b/parquet_derive/README.md
index b20721079c..c267a92430 100644
--- a/parquet_derive/README.md
+++ b/parquet_derive/README.md
@@ -19,9 +19,9 @@
 
 # Parquet Derive
 
-A crate for deriving `RecordWriter` for arbitrary, _simple_ structs. This does 
not generate writers for arbitrarily nested
-structures. It only works for primitives and a few generic structures and
-various levels of reference. Please see features checklist for what is 
currently
+A crate for deriving `RecordWriter` and `RecordReader` for arbitrary, _simple_ 
structs. This does not
+generate readers or writers for arbitrarily nested structures. It only works 
for primitives and a few
+generic structures and various levels of reference. Please see features 
checklist for what is currently
 supported.
 
 Derive also has some support for the chrono time library. You must must enable 
the `chrono` feature to get this support.
@@ -77,16 +77,55 @@ writer.close_row_group(row_group).unwrap();
 writer.close().unwrap();
 ```
 
+Example usage of deriving a `RecordReader` for your struct:
+
+```rust
+use parquet::file::{serialized_reader::SerializedFileReader, reader::FileReader};
+use parquet::record::RecordReader;
+use parquet_derive::ParquetRecordReader;
+
+#[derive(ParquetRecordReader)]
+struct ACompleteRecord {
+    pub a_bool: bool,
+    pub a_string: String,
+    pub i16: i16,
+    pub i32: i32,
+    pub u64: u64,
+    pub isize: isize,
+    pub float: f32,
+    pub double: f64,
+    pub now: chrono::NaiveDateTime,
+    pub byte_vec: Vec<u8>,
+}
+
+// Initialize your parquet file
+let reader = SerializedFileReader::new(file).unwrap();
+let mut row_group = reader.get_row_group(0).unwrap();
+
+// create your records vector to read into
+let mut chunks: Vec<ACompleteRecord> = Vec::new();
+
+// The derived `RecordReader` takes over here
+chunks.read_from_row_group(&mut *row_group, 1).unwrap();
+```
+
 ## Features
 
 - [x] Support writing `String`, `&str`, `bool`, `i32`, `f32`, `f64`, `Vec<u8>`
 - [ ] Support writing dictionaries
 - [x] Support writing logical types like timestamp
-- [x] Derive definition_levels for `Option`
-- [ ] Derive definition levels for nested structures
+- [x] Derive definition_levels for `Option` for writing
+- [ ] Derive definition levels for nested structures for writing
 - [ ] Derive writing tuple struct
 - [ ] Derive writing `tuple` container types
 
+- [x] Support reading `String`, `bool`, `i32`, `f32`, `f64`, `Vec<u8>`
+- [ ] Support reading dictionaries
+- [x] Support reading logical types like timestamp
+- [ ] Handle definition_levels for `Option` for reading
+- [ ] Handle definition levels for nested structures for reading
+- [ ] Derive reading tuple struct
+- [ ] Derive reading `tuple` container types
+
 ## Requirements
 
 - Same as `parquet-rs`
@@ -103,4 +142,4 @@ To compile and view in the browser, run `cargo doc 
--no-deps --open`.
 
 ## License
 
-Licensed under the Apache License, Version 2.0: 
http://www.apache.org/licenses/LICENSE-2.0.
+Licensed under the Apache License, Version 2.0: 
http://www.apache.org/licenses/LICENSE-2.0.
\ No newline at end of file
diff --git a/parquet_derive/src/lib.rs b/parquet_derive/src/lib.rs
index c6641cd809..671a46db0f 100644
--- a/parquet_derive/src/lib.rs
+++ b/parquet_derive/src/lib.rs
@@ -44,7 +44,7 @@ mod parquet_field;
 /// use parquet::file::writer::SerializedFileWriter;
 ///
 /// use std::sync::Arc;
-//
+///
 /// #[derive(ParquetRecordWriter)]
 /// struct ACompleteRecord<'a> {
 ///   pub a_bool: bool,
@@ -137,3 +137,89 @@ pub fn parquet_record_writer(input: 
proc_macro::TokenStream) -> proc_macro::Toke
     }
   }).into()
 }
+
+/// Derive flat, simple RecordReader implementations. Works by parsing
+/// a struct tagged with `#[derive(ParquetRecordReader)]` and emitting
+/// the correct writing code for each field of the struct. Column readers
+/// are generated in the order they are defined.
+///
+/// It is up to the programmer to keep the order of the struct
+/// fields lined up with the schema.
+///
+/// Example:
+///
+/// ```ignore
+/// use parquet::file::{serialized_reader::SerializedFileReader, 
reader::FileReader};
+/// use parquet_derive::{ParquetRecordReader};
+///
+/// #[derive(ParquetRecordReader)]
+/// struct ACompleteRecord {
+///     pub a_bool: bool,
+///     pub a_string: String,
+/// }
+///
+/// pub fn read_some_records() -> Vec<ACompleteRecord> {
+///   let mut samples: Vec<ACompleteRecord> = Vec::new();
+///
+///   let reader = SerializedFileReader::new(file).unwrap();
+///   let mut row_group = reader.get_row_group(0).unwrap();
+///   samples.read_from_row_group(&mut *row_group, 1).unwrap();
+///   samples
+/// }
+/// ```
+///
+#[proc_macro_derive(ParquetRecordReader)]
+pub fn parquet_record_reader(input: proc_macro::TokenStream) -> 
proc_macro::TokenStream {
+    let input: DeriveInput = parse_macro_input!(input as DeriveInput);
+    let fields = match input.data {
+        Data::Struct(DataStruct { fields, .. }) => fields,
+        Data::Enum(_) => unimplemented!("Enum currently is not supported"),
+        Data::Union(_) => unimplemented!("Union currently is not supported"),
+    };
+
+    let field_infos: Vec<_> = 
fields.iter().map(parquet_field::Field::from).collect();
+    let field_names: Vec<_> = fields.iter().map(|f| f.ident.clone()).collect();
+    let reader_snippets: Vec<proc_macro2::TokenStream> =
+        field_infos.iter().map(|x| x.reader_snippet()).collect();
+    let i: Vec<_> = (0..reader_snippets.len()).collect();
+
+    let derived_for = input.ident;
+    let generics = input.generics;
+
+    (quote! {
+
+    impl #generics ::parquet::record::RecordReader<#derived_for #generics> for 
Vec<#derived_for #generics> {
+      fn read_from_row_group(
+        &mut self,
+        row_group_reader: &mut dyn ::parquet::file::reader::RowGroupReader,
+        num_records: usize,
+      ) -> Result<(), ::parquet::errors::ParquetError> {
+        use ::parquet::column::reader::ColumnReader;
+
+        let mut row_group_reader = row_group_reader;
+
+        for _ in 0..num_records {
+          self.push(#derived_for {
+            #(
+              #field_names: Default::default()
+            ),*
+          })
+        }
+
+        let records = self; // Used by all the reader snippets to be more clear
+
+        #(
+          {
+              if let Ok(mut column_reader) = 
row_group_reader.get_column_reader(#i) {
+                  #reader_snippets
+              } else {
+                  return Err(::parquet::errors::ParquetError::General("Failed 
to get next column".into()))
+              }
+          }
+        );*
+
+        Ok(())
+      }
+    }
+  }).into()
+}
diff --git a/parquet_derive/src/parquet_field.rs 
b/parquet_derive/src/parquet_field.rs
index e629bfe757..0ac95c2864 100644
--- a/parquet_derive/src/parquet_field.rs
+++ b/parquet_derive/src/parquet_field.rs
@@ -219,6 +219,72 @@ impl Field {
         }
     }
 
+    /// Takes the parsed field of the struct and emits a valid
+    /// column reader snippet. Should match exactly what you
+    /// would write by hand.
+    ///
+    /// The emitted snippet relies on these names being in scope where it is
+    /// expanded: `column_reader` (the `ColumnReader` for this field's column),
+    /// `num_records` (how many records to read) and `records` (indexable as
+    /// `records[..num_records]`, yielding the records to fill in). It uses `?`
+    /// on the read, so it must be expanded inside a function whose error type
+    /// accepts a `ParquetError`.
+    ///
+    /// Can only generate readers for basic structs, for example:
+    ///
+    /// struct Record {
+    ///   a_bool: bool
+    /// }
+    ///
+    /// but not
+    ///
+    /// struct UnsupportedNestedRecord {
+    ///   a_property: bool,
+    ///   nested_record: Record
+    /// }
+    ///
+    /// because this parsing logic is not sophisticated enough for definition
+    /// levels beyond 2.
+    ///
+    /// `Option` types and references are not supported: the generated code
+    /// reads no definition or repetition levels (it passes `None` for both).
+    ///
+    /// NOTE(review): the count returned by `read_records` is ignored, so if the
+    /// column holds fewer than `num_records` values, the trailing slots of
+    /// `vals` still hold `Default::default()` of the physical type and are
+    /// converted as if they had been read.
+    pub fn reader_snippet(&self) -> proc_macro2::TokenStream {
+        let ident = &self.ident;
+        let column_reader = self.ty.column_reader();
+        let parquet_type = self.ty.physical_type_as_rust();
+
+        // generate the code to read the column into a vector `vals`
+        let read_batch_expr = quote! {
+            let mut vals_vec = Vec::new();
+            vals_vec.resize(num_records, Default::default());
+            let mut vals: &mut [#parquet_type] = vals_vec.as_mut_slice();
+            if let #column_reader(mut typed) = column_reader {
+                typed.read_records(num_records, None, None, vals)?;
+            } else {
+                panic!("Schema and struct disagree on type for {}", stringify!{#ident});
+            }
+        };
+
+        // generate the code to convert each element of `vals` to the field's type
+        // and store it in that field of the corresponding struct
+        let store_vals = match &self.ty {
+            Type::TypePath(_) => self.copied_direct_fields(),
+            Type::Reference(_, ref first_type) => match **first_type {
+                Type::TypePath(_) => self.copied_direct_fields(),
+                Type::Slice(ref second_type) => match **second_type {
+                    Type::TypePath(_) => self.copied_direct_fields(),
+                    ref f => unimplemented!("Unsupported: {:#?}", f),
+                },
+                ref f => unimplemented!("Unsupported: {:#?}", f),
+            },
+            Type::Vec(ref first_type) => match **first_type {
+                Type::TypePath(_) => self.copied_direct_fields(),
+                ref f => unimplemented!("Unsupported: {:#?}", f),
+            },
+            f => unimplemented!("Unsupported: {:#?}", f),
+        };
+
+        quote! {
+            {
+                #read_batch_expr
+
+                #store_vals
+            }
+        }
+    }
+
     pub fn parquet_type(&self) -> proc_macro2::TokenStream {
         // TODO: Support group types
         // TODO: Add length if dealing with fixedlenbinary
@@ -319,27 +385,31 @@ impl Field {
         }
     }
 
+    // generates code to read `field_name` from each record into a vector 
`vals`
     fn copied_direct_vals(&self) -> proc_macro2::TokenStream {
         let field_name = &self.ident;
-        let is_a_byte_buf = self.is_a_byte_buf;
-        let is_a_timestamp = self.third_party_type == 
Some(ThirdPartyType::ChronoNaiveDateTime);
-        let is_a_date = self.third_party_type == 
Some(ThirdPartyType::ChronoNaiveDate);
-        let is_a_uuid = self.third_party_type == Some(ThirdPartyType::Uuid);
 
-        let access = if is_a_timestamp {
-            quote! { rec.#field_name.timestamp_millis() }
-        } else if is_a_date {
-            quote! { 
rec.#field_name.signed_duration_since(::chrono::NaiveDate::from_ymd(1970, 1, 
1)).num_days() as i32 }
-        } else if is_a_uuid {
-            quote! { (&rec.#field_name.to_string()[..]).into() }
-        } else if is_a_byte_buf {
-            quote! { (&rec.#field_name[..]).into() }
-        } else {
-            // Type might need converting to a physical type
-            match self.ty.physical_type() {
-                parquet::basic::Type::INT32 => quote! { rec.#field_name as i32 
},
-                parquet::basic::Type::INT64 => quote! { rec.#field_name as i64 
},
-                _ => quote! { rec.#field_name },
+        let access = match self.third_party_type {
+            Some(ThirdPartyType::ChronoNaiveDateTime) => {
+                quote! { rec.#field_name.timestamp_millis() }
+            }
+            Some(ThirdPartyType::ChronoNaiveDate) => {
+                quote! { 
rec.#field_name.signed_duration_since(::chrono::NaiveDate::from_ymd(1970, 1, 
1)).num_days() as i32 }
+            }
+            Some(ThirdPartyType::Uuid) => {
+                quote! { (&rec.#field_name.to_string()[..]).into() }
+            }
+            _ => {
+                if self.is_a_byte_buf {
+                    quote! { (&rec.#field_name[..]).into() }
+                } else {
+                    // Type might need converting to a physical type
+                    match self.ty.physical_type() {
+                        parquet::basic::Type::INT32 => quote! { 
rec.#field_name as i32 },
+                        parquet::basic::Type::INT64 => quote! { 
rec.#field_name as i64 },
+                        _ => quote! { rec.#field_name },
+                    }
+                }
             }
         };
 
@@ -348,6 +418,48 @@ impl Field {
         }
     }
 
+    /// Generates the loop that converts each value already read into `vals`
+    /// to this field's Rust type and stores it into `field_name` of the
+    /// matching element of `records[..num_records]`.
+    ///
+    /// NOTE(review): the `Uuid` arm emits `vals[i].data().convert()`. The
+    /// `String` arm below passes `data()` to `std::str::from_utf8` as a byte
+    /// slice, and no `convert` method for it is visible here; confirm the
+    /// generated Uuid code compiles (no visible test derives a reader for a
+    /// Uuid field).
+    fn copied_direct_fields(&self) -> proc_macro2::TokenStream {
+        let field_name = &self.ident;
+
+        let value = match self.third_party_type {
+            // the writer stores `NaiveDateTime` as `timestamp_millis()`
+            Some(ThirdPartyType::ChronoNaiveDateTime) => {
+                quote! { 
::chrono::naive::NaiveDateTime::from_timestamp_millis(vals[i]).unwrap() }
+            }
+            // the writer stores `NaiveDate` as days since 1970-01-01; shift by the
+            // distance from 0000-12-31 to 1970-01-01 to get the day count that
+            // `from_num_days_from_ce_opt` takes
+            Some(ThirdPartyType::ChronoNaiveDate) => {
+                quote! {
+                    
::chrono::naive::NaiveDate::from_num_days_from_ce_opt(vals[i]
+                + ((::chrono::naive::NaiveDate::from_ymd_opt(1970, 1, 
1).unwrap()
+                        .signed_duration_since(
+                            ::chrono::naive::NaiveDate::from_ymd_opt(0, 12, 
31).unwrap()
+                        )
+                   ).num_days()) as i32).unwrap()
+                }
+            }
+            // see NOTE(review) above about `convert()`
+            Some(ThirdPartyType::Uuid) => {
+                quote! { 
::uuid::Uuid::parse_str(vals[i].data().convert()).unwrap() }
+            }
+            _ => match &self.ty {
+                Type::TypePath(_) => match self.ty.last_part().as_str() {
+                    "String" => quote! { 
String::from(std::str::from_utf8(vals[i].data())
+                    .expect("invalid UTF-8 sequence")) },
+                    // other primitives: cast from the physical type with `as`,
+                    // which wraps/truncates silently if the value does not fit
+                    t => {
+                        let s: proc_macro2::TokenStream = t.parse().unwrap();
+                        quote! { vals[i] as #s }
+                    }
+                },
+                // `Vec<u8>` fields copy the raw bytes
+                Type::Vec(_) => quote! { vals[i].data().to_vec() },
+                f => unimplemented!("Unsupported: {:#?}", f),
+            },
+        };
+
+        // the leading `&mut` on the iterator is redundant but harmless; the
+        // reader snippet tests pin this exact token form
+        quote! {
+            for (i, r) in &mut records[..num_records].iter_mut().enumerate() {
+                r.#field_name = #value;
+            }
+        }
+    }
+
     fn optional_definition_levels(&self) -> proc_macro2::TokenStream {
         let field_name = &self.ident;
 
@@ -396,6 +508,29 @@ impl Type {
         }
     }
 
+    /// Maps this type's parquet physical type to the `ColumnReader` enum
+    /// variant path that the generated reader code matches on.
+    fn column_reader(&self) -> syn::TypePath {
+        use parquet::basic::Type as Physical;
+
+        match self.physical_type() {
+            Physical::BYTE_ARRAY => syn::parse_quote!(ColumnReader::ByteArrayColumnReader),
+            Physical::FIXED_LEN_BYTE_ARRAY => {
+                syn::parse_quote!(ColumnReader::FixedLenByteArrayColumnReader)
+            }
+            Physical::BOOLEAN => syn::parse_quote!(ColumnReader::BoolColumnReader),
+            Physical::INT32 => syn::parse_quote!(ColumnReader::Int32ColumnReader),
+            Physical::INT64 => syn::parse_quote!(ColumnReader::Int64ColumnReader),
+            Physical::INT96 => syn::parse_quote!(ColumnReader::Int96ColumnReader),
+            Physical::FLOAT => syn::parse_quote!(ColumnReader::FloatColumnReader),
+            Physical::DOUBLE => syn::parse_quote!(ColumnReader::DoubleColumnReader),
+        }
+    }
+
     /// Helper to simplify a nested field definition to its leaf type
     ///
     /// Ex:
@@ -515,6 +650,23 @@ impl Type {
         }
     }
 
+    /// Returns the Rust type used to buffer values of this type's parquet
+    /// physical type while reading, i.e. the element type of `vals` in the
+    /// generated reader code. INT96 has no buffer type yet and panics.
+    fn physical_type_as_rust(&self) -> proc_macro2::TokenStream {
+        use parquet::basic::Type as Physical;
+
+        match self.physical_type() {
+            Physical::BYTE_ARRAY => quote! { ::parquet::data_type::ByteArray },
+            Physical::FIXED_LEN_BYTE_ARRAY => quote! { ::parquet::data_type::FixedLenByteArray },
+            Physical::BOOLEAN => quote! { bool },
+            Physical::INT32 => quote! { i32 },
+            Physical::INT64 => quote! { i64 },
+            Physical::FLOAT => quote! { f32 },
+            Physical::DOUBLE => quote! { f64 },
+            Physical::INT96 => unimplemented!("96-bit int currently is not supported"),
+        }
+    }
+
     fn logical_type(&self) -> proc_macro2::TokenStream {
         let last_part = self.last_part();
         let leaf_type = self.leaf_type_recursive();
@@ -713,6 +865,39 @@ mod test {
         )
     }
 
+    #[test]
+    fn test_generating_a_simple_reader_snippet() {
+        // Pins the exact token stream `reader_snippet` emits for a plain `usize`
+        // field: an INT64 column read into `vals`, then cast back with `as usize`.
+        let snippet: proc_macro2::TokenStream = quote! {
+          struct ABoringStruct {
+            counter: usize,
+          }
+        };
+
+        let fields = extract_fields(snippet);
+        let counter = Field::from(&fields[0]);
+
+        let snippet = counter.reader_snippet().to_string();
+        assert_eq!(
+            snippet,
+            (quote! {
+                 {
+                     let mut vals_vec = Vec::new();
+                     vals_vec.resize(num_records, Default::default());
+                     let mut vals: &mut[i64] = vals_vec.as_mut_slice();
+                     if let ColumnReader::Int64ColumnReader(mut typed) = 
column_reader {
+                         typed.read_records(num_records, None, None, vals)?;
+                     } else {
+                         panic!("Schema and struct disagree on type for {}", 
stringify!{ counter });
+                     }
+                     for (i, r) in &mut 
records[..num_records].iter_mut().enumerate() {
+                         r.counter = vals[i] as usize;
+                     }
+                 }
+            })
+            .to_string()
+        )
+    }
+
     #[test]
     fn test_optional_to_writer_snippet() {
         let struct_def: proc_macro2::TokenStream = quote! {
@@ -822,6 +1007,32 @@ mod test {
         );
     }
 
+    #[test]
+    fn test_converting_to_column_reader_type() {
+        // `bool` maps to the BOOLEAN reader and `String` to the BYTE_ARRAY reader.
+        let snippet: proc_macro2::TokenStream = quote! {
+          struct ABasicStruct {
+            yes_no: bool,
+            name: String,
+          }
+        };
+
+        let fields = extract_fields(snippet);
+        let processed: Vec<_> = fields.iter().map(Field::from).collect();
+
+        let column_readers: Vec<_> = processed
+            .iter()
+            .map(|field| field.ty.column_reader())
+            .collect();
+
+        assert_eq!(
+            column_readers,
+            vec![
+                syn::parse_quote!(ColumnReader::BoolColumnReader),
+                syn::parse_quote!(ColumnReader::ByteArrayColumnReader)
+            ]
+        );
+    }
+
     #[test]
     fn convert_basic_struct() {
         let snippet: proc_macro2::TokenStream = quote! {
@@ -995,7 +1206,7 @@ mod test {
     }
 
     #[test]
-    fn test_chrono_timestamp_millis() {
+    fn test_chrono_timestamp_millis_write() {
         let snippet: proc_macro2::TokenStream = quote! {
           struct ATimestampStruct {
             henceforth: chrono::NaiveDateTime,
@@ -1038,7 +1249,34 @@ mod test {
     }
 
     #[test]
-    fn test_chrono_date() {
+    fn test_chrono_timestamp_millis_read() {
+        // Pins the NaiveDateTime reader: an INT64 column of milliseconds turned
+        // back into a timestamp with `from_timestamp_millis`.
+        let snippet: proc_macro2::TokenStream = quote! {
+          struct ATimestampStruct {
+            henceforth: chrono::NaiveDateTime,
+          }
+        };
+
+        let fields = extract_fields(snippet);
+        let when = Field::from(&fields[0]);
+        assert_eq!(when.reader_snippet().to_string(),(quote!{
+            {
+                let mut vals_vec = Vec::new();
+                vals_vec.resize(num_records, Default::default());
+                let mut vals: &mut[i64] = vals_vec.as_mut_slice();
+                if let ColumnReader::Int64ColumnReader(mut typed) = 
column_reader {
+                    typed.read_records(num_records, None, None, vals)?;
+                } else {
+                    panic!("Schema and struct disagree on type for {}", 
stringify!{ henceforth });
+                }
+                for (i, r) in &mut 
records[..num_records].iter_mut().enumerate() {
+                    r.henceforth = 
::chrono::naive::NaiveDateTime::from_timestamp_millis(vals[i]).unwrap();
+                }
+            }
+        }).to_string());
+    }
+
+    #[test]
+    fn test_chrono_date_write() {
         let snippet: proc_macro2::TokenStream = quote! {
           struct ATimestampStruct {
             henceforth: chrono::NaiveDate,
@@ -1081,7 +1319,38 @@ mod test {
     }
 
     #[test]
-    fn test_uuid() {
+    fn test_chrono_date_read() {
+        // Pins the NaiveDate reader: INT32 day counts are shifted by the
+        // distance from 0000-12-31 to 1970-01-01 before `from_num_days_from_ce_opt`.
+        let snippet: proc_macro2::TokenStream = quote! {
+          struct ATimestampStruct {
+            henceforth: chrono::NaiveDate,
+          }
+        };
+
+        let fields = extract_fields(snippet);
+        let when = Field::from(&fields[0]);
+        assert_eq!(when.reader_snippet().to_string(),(quote!{
+            {
+                let mut vals_vec = Vec::new();
+                vals_vec.resize(num_records, Default::default());
+                let mut vals: &mut [i32] = vals_vec.as_mut_slice();
+                if let ColumnReader::Int32ColumnReader(mut typed) = 
column_reader {
+                    typed.read_records(num_records, None, None, vals)?;
+                } else {
+                    panic!("Schema and struct disagree on type for {}", 
stringify!{ henceforth });
+                }
+                for (i, r) in &mut 
records[..num_records].iter_mut().enumerate() {
+                    r.henceforth = 
::chrono::naive::NaiveDate::from_num_days_from_ce_opt(vals[i]
+                        + ((::chrono::naive::NaiveDate::from_ymd_opt(1970, 1, 
1).unwrap()
+                        .signed_duration_since(
+                            ::chrono::naive::NaiveDate::from_ymd_opt(0, 12, 
31).unwrap()
+                        )).num_days()) as i32).unwrap();
+                }
+            }
+        }).to_string());
+    }
+
+    #[test]
+    fn test_uuid_write() {
         let snippet: proc_macro2::TokenStream = quote! {
           struct AUuidStruct {
             unique_id: uuid::Uuid,
@@ -1123,6 +1392,33 @@ mod test {
         }).to_string());
     }
 
+    #[test]
+    fn test_uuid_read() {
+        // Pins the token stream emitted for a Uuid field (BYTE_ARRAY column).
+        // NOTE(review): this only compares tokens; no visible test expands
+        // `vals[i].data().convert()` in real code, so confirm it compiles.
+        let snippet: proc_macro2::TokenStream = quote! {
+          struct AUuidStruct {
+            unique_id: uuid::Uuid,
+          }
+        };
+
+        let fields = extract_fields(snippet);
+        let when = Field::from(&fields[0]);
+        assert_eq!(when.reader_snippet().to_string(),(quote!{
+            {
+                let mut vals_vec = Vec::new();
+                vals_vec.resize(num_records, Default::default());
+                let mut vals: &mut [::parquet::data_type::ByteArray] = 
vals_vec.as_mut_slice();
+                if let ColumnReader::ByteArrayColumnReader(mut typed) = 
column_reader {
+                    typed.read_records(num_records, None, None, vals)?;
+                } else {
+                    panic!("Schema and struct disagree on type for {}", 
stringify!{ unique_id });
+                }
+                for (i, r) in &mut 
records[..num_records].iter_mut().enumerate() {
+                    r.unique_id = 
::uuid::Uuid::parse_str(vals[i].data().convert()).unwrap();
+                }
+            }
+        }).to_string());
+    }
+
     #[test]
     fn test_converted_type() {
         let snippet: proc_macro2::TokenStream = quote! {
diff --git a/parquet_derive_test/src/lib.rs b/parquet_derive_test/src/lib.rs
index d377fb0a62..a8b631ecc0 100644
--- a/parquet_derive_test/src/lib.rs
+++ b/parquet_derive_test/src/lib.rs
@@ -17,7 +17,7 @@
 
 #![allow(clippy::approx_constant)]
 
-use parquet_derive::ParquetRecordWriter;
+use parquet_derive::{ParquetRecordReader, ParquetRecordWriter};
 
 #[derive(ParquetRecordWriter)]
 struct ACompleteRecord<'a> {
@@ -49,6 +49,21 @@ struct ACompleteRecord<'a> {
     pub borrowed_maybe_borrowed_byte_vec: &'a Option<&'a [u8]>,
 }
 
+/// Fixture for the read/write round-trip test. It uses only field types that
+/// both derives handle (no references or `Option`s). The reader reads columns
+/// by index in field declaration order, so field order matters.
+#[derive(PartialEq, ParquetRecordWriter, ParquetRecordReader, Debug)]
+struct APartiallyCompleteRecord {
+    pub bool: bool,
+    pub string: String,
+    pub i16: i16,
+    pub i32: i32,
+    pub u64: u64,
+    pub isize: isize,
+    pub float: f32,
+    pub double: f64,
+    pub now: chrono::NaiveDateTime,
+    pub date: chrono::NaiveDate,
+    pub byte_vec: Vec<u8>,
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -56,7 +71,8 @@ mod tests {
     use std::{env, fs, io::Write, sync::Arc};
 
     use parquet::{
-        file::writer::SerializedFileWriter, record::RecordWriter,
+        file::writer::SerializedFileWriter,
+        record::{RecordReader, RecordWriter},
         schema::parser::parse_message_type,
     };
 
@@ -147,6 +163,56 @@ mod tests {
         writer.close().unwrap();
     }
 
+    #[test]
+    fn test_parquet_derive_read_write_combined() {
+        use parquet::file::{reader::FileReader, serialized_reader::SerializedFileReader};
+
+        let file = get_temp_file("test_parquet_derive_combined", &[]);
+
+        let mut drs: Vec<APartiallyCompleteRecord> = vec![APartiallyCompleteRecord {
+            bool: true,
+            string: "a string".into(),
+            i16: -45,
+            i32: 456,
+            u64: 4563424,
+            isize: -365,
+            float: 3.5,
+            double: f64::NAN,
+            now: chrono::Utc::now().naive_local(),
+            date: chrono::naive::NaiveDate::from_ymd_opt(2015, 3, 14).unwrap(),
+            byte_vec: vec![0x65, 0x66, 0x67],
+        }];
+
+        // write the single record with the derived `RecordWriter`
+        let generated_schema = drs.as_slice().schema().unwrap();
+
+        let props = Default::default();
+        let mut writer =
+            SerializedFileWriter::new(file.try_clone().unwrap(), generated_schema, props).unwrap();
+
+        let mut row_group = writer.next_row_group().unwrap();
+        drs.as_slice().write_to_row_group(&mut row_group).unwrap();
+        row_group.close().unwrap();
+        writer.close().unwrap();
+
+        // read it back with the derived `RecordReader`
+        let reader = SerializedFileReader::new(file).unwrap();
+
+        let mut out: Vec<APartiallyCompleteRecord> = Vec::new();
+        let mut row_group = reader.get_row_group(0).unwrap();
+        out.read_from_row_group(&mut *row_group, 1).unwrap();
+
+        // exactly one record was requested into an empty vector; check this first
+        // so a mismatch fails clearly instead of with an index panic below
+        assert_eq!(out.len(), 1);
+
+        // the writer stores timestamps as milliseconds, so drop finer precision
+        drs[0].now =
+            chrono::naive::NaiveDateTime::from_timestamp_millis(drs[0].now.timestamp_millis())
+                .unwrap();
+
+        // NaN != NaN: check the NaN survived, then neutralise it before comparing
+        assert!(out[0].double.is_nan());
+        out[0].double = 0.;
+        drs[0].double = 0.;
+
+        assert_eq!(drs[0], out[0]);
+    }
+
     /// Returns file handle for a temp file in 'target' directory with a 
provided content
     pub fn get_temp_file(file_name: &str, content: &[u8]) -> fs::File {
         // build tmp path to a file in "target/debug/testdata"

Reply via email to