This is an automated email from the ASF dual-hosted git repository.

liurenjie1024 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-rust.git


The following commit(s) were added to refs/heads/main by this push:
     new 6e5a871  feat: Read Parquet data file with projection (#245)
6e5a871 is described below

commit 6e5a871ff3a9e2b7f30c59f9711d3e709378b62b
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Sun Mar 31 19:09:46 2024 -0700

    feat: Read Parquet data file with projection (#245)
    
    * feat: Read Parquet data file with projection
    
    * fix
    
    * Update
    
    * More
    
    * For review
    
    * Use FeatureUnsupported error.
---
 crates/iceberg/src/arrow.rs |  99 ++++++++++++++++++++++++++++++++++++++-----
 crates/iceberg/src/scan.rs  | 101 ++++++++++++++++++++++++++++++++++++++------
 2 files changed, 178 insertions(+), 22 deletions(-)

diff --git a/crates/iceberg/src/arrow.rs b/crates/iceberg/src/arrow.rs
index 47cbaa1..4b23df8 100644
--- a/crates/iceberg/src/arrow.rs
+++ b/crates/iceberg/src/arrow.rs
@@ -17,12 +17,16 @@
 
 //! Parquet file data reader
 
+use arrow_schema::SchemaRef as ArrowSchemaRef;
 use async_stream::try_stream;
 use futures::stream::StreamExt;
-use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask};
+use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask, PARQUET_FIELD_ID_META_KEY};
+use parquet::schema::types::SchemaDescriptor;
+use std::collections::HashMap;
+use std::str::FromStr;
 
 use crate::io::FileIO;
-use crate::scan::{ArrowRecordBatchStream, FileScanTask, FileScanTaskStream};
+use crate::scan::{ArrowRecordBatchStream, FileScanTaskStream};
 use crate::spec::SchemaRef;
 
 use crate::error::Result;
@@ -36,6 +40,7 @@ use std::sync::Arc;
 /// Builder to create ArrowReader
 pub struct ArrowReaderBuilder {
     batch_size: Option<usize>,
+    field_ids: Vec<usize>,
     file_io: FileIO,
     schema: SchemaRef,
 }
@@ -45,6 +50,7 @@ impl ArrowReaderBuilder {
     pub fn new(file_io: FileIO, schema: SchemaRef) -> Self {
         ArrowReaderBuilder {
             batch_size: None,
+            field_ids: vec![],
             file_io,
             schema,
         }
@@ -57,10 +63,17 @@ impl ArrowReaderBuilder {
         self
     }
 
+    /// Sets the desired column projection with a list of field ids.
+    pub fn with_field_ids(mut self, field_ids: impl IntoIterator<Item = usize>) -> Self {
+        self.field_ids = field_ids.into_iter().collect();
+        self
+    }
+
     /// Build the ArrowReader.
     pub fn build(self) -> ArrowReader {
         ArrowReader {
             batch_size: self.batch_size,
+            field_ids: self.field_ids,
             schema: self.schema,
             file_io: self.file_io,
         }
@@ -70,6 +83,7 @@ impl ArrowReaderBuilder {
 /// Reads data from Parquet files
 pub struct ArrowReader {
     batch_size: Option<usize>,
+    field_ids: Vec<usize>,
     #[allow(dead_code)]
     schema: SchemaRef,
     file_io: FileIO,
@@ -83,17 +97,18 @@ impl ArrowReader {
 
         Ok(try_stream! {
             while let Some(Ok(task)) = tasks.next().await {
-
-                let projection_mask = self.get_arrow_projection_mask(&task);
-
                 let parquet_reader = file_io
                     .new_input(task.data().data_file().file_path())?
                     .reader()
                     .await?;
 
                 let mut batch_stream_builder = ParquetRecordBatchStreamBuilder::new(parquet_reader)
-                    .await?
-                    .with_projection(projection_mask);
+                    .await?;
+
+                let parquet_schema = batch_stream_builder.parquet_schema();
+                let arrow_schema = batch_stream_builder.schema();
+                let projection_mask = self.get_arrow_projection_mask(parquet_schema, arrow_schema)?;
+                batch_stream_builder = batch_stream_builder.with_projection(projection_mask);
 
                 if let Some(batch_size) = self.batch_size {
                     batch_stream_builder = batch_stream_builder.with_batch_size(batch_size);
@@ -109,9 +124,73 @@ impl ArrowReader {
         .boxed())
     }
 
-    fn get_arrow_projection_mask(&self, _task: &FileScanTask) -> ProjectionMask {
-        // TODO: full implementation
-        ProjectionMask::all()
+    fn get_arrow_projection_mask(
+        &self,
+        parquet_schema: &SchemaDescriptor,
+        arrow_schema: &ArrowSchemaRef,
+    ) -> crate::Result<ProjectionMask> {
+        if self.field_ids.is_empty() {
+            Ok(ProjectionMask::all())
+        } else {
+            // Build the map between field id and column index in Parquet schema.
+            let mut column_map = HashMap::new();
+
+            let fields = arrow_schema.fields();
+            let iceberg_schema = arrow_schema_to_schema(arrow_schema)?;
+            fields.filter_leaves(|idx, field| {
+                let field_id = field.metadata().get(PARQUET_FIELD_ID_META_KEY);
+                if field_id.is_none() {
+                    return false;
+                }
+
+                let field_id = i32::from_str(field_id.unwrap());
+                if field_id.is_err() {
+                    return false;
+                }
+                let field_id = field_id.unwrap();
+
+                if !self.field_ids.contains(&(field_id as usize)) {
+                    return false;
+                }
+
+                let iceberg_field = self.schema.field_by_id(field_id);
+                let parquet_iceberg_field = iceberg_schema.field_by_id(field_id);
+
+                if iceberg_field.is_none() || parquet_iceberg_field.is_none() {
+                    return false;
+                }
+
+                if iceberg_field.unwrap().field_type != parquet_iceberg_field.unwrap().field_type {
+                    return false;
+                }
+
+                column_map.insert(field_id, idx);
+                true
+            });
+
+            if column_map.len() != self.field_ids.len() {
+                return Err(Error::new(
+                    ErrorKind::DataInvalid,
+                    format!(
+                        "Parquet schema {} and Iceberg schema {} do not match.",
+                        iceberg_schema, self.schema
+                    ),
+                ));
+            }
+
+            let mut indices = vec![];
+            for field_id in &self.field_ids {
+                if let Some(col_idx) = column_map.get(&(*field_id as i32)) {
+                    indices.push(*col_idx);
+                } else {
+                    return Err(Error::new(
+                        ErrorKind::DataInvalid,
+                        format!("Field {} is not found in Parquet schema.", field_id),
+                    ));
+                }
+            }
+            Ok(ProjectionMask::leaves(parquet_schema, indices))
+        }
     }
 }
 
diff --git a/crates/iceberg/src/scan.rs b/crates/iceberg/src/scan.rs
index 358de5d..b96c470 100644
--- a/crates/iceberg/src/scan.rs
+++ b/crates/iceberg/src/scan.rs
@@ -109,7 +109,10 @@ impl<'a> TableScanBuilder<'a> {
                 if schema.field_by_name(column_name).is_none() {
                     return Err(Error::new(
                         ErrorKind::DataInvalid,
-                        format!("Column {} not found in table.", column_name),
+                        format!(
+                            "Column {} not found in table. Schema: {}",
+                            column_name, schema
+                        ),
                     ));
                 }
             }
@@ -187,6 +190,46 @@ impl TableScan {
         let mut arrow_reader_builder =
             ArrowReaderBuilder::new(self.file_io.clone(), self.schema.clone());
 
+        let mut field_ids = vec![];
+        for column_name in &self.column_names {
+            let field_id = self.schema.field_id_by_name(column_name).ok_or_else(|| {
+                Error::new(
+                    ErrorKind::DataInvalid,
+                    format!(
+                        "Column {} not found in table. Schema: {}",
+                        column_name, self.schema
+                    ),
+                )
+            })?;
+
+            let field = self.schema
+                .as_struct()
+                .field_by_id(field_id)
+                .ok_or_else(|| {
+                    Error::new(
+                        ErrorKind::FeatureUnsupported,
+                        format!(
+                            "Column {} is not a direct child of schema but a nested field, which is not supported now. Schema: {}",
+                            column_name, self.schema
+                        ),
+                    )
+                })?;
+
+            if !field.field_type.is_primitive() {
+                return Err(Error::new(
+                    ErrorKind::FeatureUnsupported,
+                    format!(
+                        "Column {} is not a primitive type. Schema: {}",
+                        column_name, self.schema
+                    ),
+                ));
+            }
+
+            field_ids.push(field_id as usize);
+        }
+
+        arrow_reader_builder = arrow_reader_builder.with_field_ids(field_ids);
+
         if let Some(batch_size) = self.batch_size {
             arrow_reader_builder = arrow_reader_builder.with_batch_size(batch_size);
         }
@@ -390,18 +433,29 @@ mod tests {
 
             // prepare data
             let schema = {
-                let fields =
-                    vec![
-                        arrow_schema::Field::new("col", arrow_schema::DataType::Int64, true)
-                            .with_metadata(HashMap::from([(
-                                PARQUET_FIELD_ID_META_KEY.to_string(),
-                                "0".to_string(),
-                            )])),
-                    ];
+                let fields = vec![
+                    arrow_schema::Field::new("x", arrow_schema::DataType::Int64, false)
+                        .with_metadata(HashMap::from([(
+                            PARQUET_FIELD_ID_META_KEY.to_string(),
+                            "1".to_string(),
+                        )])),
+                    arrow_schema::Field::new("y", arrow_schema::DataType::Int64, false)
+                        .with_metadata(HashMap::from([(
+                            PARQUET_FIELD_ID_META_KEY.to_string(),
+                            "2".to_string(),
+                        )])),
+                    arrow_schema::Field::new("z", arrow_schema::DataType::Int64, false)
+                        .with_metadata(HashMap::from([(
+                            PARQUET_FIELD_ID_META_KEY.to_string(),
+                            "3".to_string(),
+                        )])),
+                ];
                 Arc::new(arrow_schema::Schema::new(fields))
             };
-            let col = Arc::new(Int64Array::from_iter_values(vec![1; 1024])) as ArrayRef;
-            let to_write = RecordBatch::try_new(schema.clone(), vec![col]).unwrap();
+            let col1 = Arc::new(Int64Array::from_iter_values(vec![1; 1024])) as ArrayRef;
+            let col2 = Arc::new(Int64Array::from_iter_values(vec![2; 1024])) as ArrayRef;
+            let col3 = Arc::new(Int64Array::from_iter_values(vec![3; 1024])) as ArrayRef;
+            let to_write = RecordBatch::try_new(schema.clone(), vec![col1, col2, col3]).unwrap();
 
             // Write the Parquet files
             let props = WriterProperties::builder()
@@ -531,9 +585,32 @@ mod tests {
 
         let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
 
-        let col = batches[0].column_by_name("col").unwrap();
+        let col = batches[0].column_by_name("x").unwrap();
 
         let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
         assert_eq!(int64_arr.value(0), 1);
     }
+
+    #[tokio::test]
+    async fn test_open_parquet_with_projection() {
+        let mut fixture = TableTestFixture::new();
+        fixture.setup_manifest_files().await;
+
+        // Create table scan for current snapshot and plan files
+        let table_scan = fixture.table.scan().select(["x", "z"]).build().unwrap();
+
+        let batch_stream = table_scan.to_arrow().await.unwrap();
+
+        let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+
+        assert_eq!(batches[0].num_columns(), 2);
+
+        let col1 = batches[0].column_by_name("x").unwrap();
+        let int64_arr = col1.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 1);
+
+        let col2 = batches[0].column_by_name("z").unwrap();
+        let int64_arr = col2.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 3);
+    }
 }

Reply via email to