This is an automated email from the ASF dual-hosted git repository.

liurenjie1024 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-rust.git


The following commit(s) were added to refs/heads/main by this push:
     new 854171d  Enhancement: refine the reader interface (#401)
854171d is described below

commit 854171d42199f756d2ad1a81a742ce61d9299f05
Author: ZENOTME <[email protected]>
AuthorDate: Fri Jun 21 08:25:51 2024 +0800

    Enhancement: refine the reader interface (#401)
---
 crates/iceberg/src/arrow/reader.rs |  69 +++++--------
 crates/iceberg/src/scan.rs         | 203 +++++++++++++++++++++++++++----------
 crates/iceberg/src/spec/schema.rs  |   2 +-
 crates/iceberg/src/table.rs        |  11 ++
 4 files changed, 188 insertions(+), 97 deletions(-)

diff --git a/crates/iceberg/src/arrow/reader.rs b/crates/iceberg/src/arrow/reader.rs
index 7d04a59..f63817f 100644
--- a/crates/iceberg/src/arrow/reader.rs
+++ b/crates/iceberg/src/arrow/reader.rs
@@ -44,27 +44,21 @@ use crate::expr::visitors::bound_predicate_visitor::{visit, BoundPredicateVisito
 use crate::expr::{BoundPredicate, BoundReference};
 use crate::io::{FileIO, FileMetadata, FileRead};
 use crate::scan::{ArrowRecordBatchStream, FileScanTaskStream};
-use crate::spec::{Datum, SchemaRef};
+use crate::spec::{Datum, Schema};
 use crate::{Error, ErrorKind};
 
 /// Builder to create ArrowReader
 pub struct ArrowReaderBuilder {
     batch_size: Option<usize>,
-    field_ids: Vec<usize>,
     file_io: FileIO,
-    schema: SchemaRef,
-    predicates: Option<BoundPredicate>,
 }
 
 impl ArrowReaderBuilder {
     /// Create a new ArrowReaderBuilder
-    pub fn new(file_io: FileIO, schema: SchemaRef) -> Self {
+    pub(crate) fn new(file_io: FileIO) -> Self {
         ArrowReaderBuilder {
             batch_size: None,
-            field_ids: vec![],
             file_io,
-            schema,
-            predicates: None,
         }
     }
 
@@ -75,38 +69,20 @@ impl ArrowReaderBuilder {
         self
     }
 
-    /// Sets the desired column projection with a list of field ids.
-    pub fn with_field_ids(mut self, field_ids: impl IntoIterator<Item = usize>) -> Self {
-        self.field_ids = field_ids.into_iter().collect();
-        self
-    }
-
-    /// Sets the predicates to apply to the scan.
-    pub fn with_predicates(mut self, predicates: BoundPredicate) -> Self {
-        self.predicates = Some(predicates);
-        self
-    }
-
     /// Build the ArrowReader.
     pub fn build(self) -> ArrowReader {
         ArrowReader {
             batch_size: self.batch_size,
-            field_ids: self.field_ids,
-            schema: self.schema,
             file_io: self.file_io,
-            predicates: self.predicates,
         }
     }
 }
 
 /// Reads data from Parquet files
+#[derive(Clone)]
 pub struct ArrowReader {
     batch_size: Option<usize>,
-    field_ids: Vec<usize>,
-    #[allow(dead_code)]
-    schema: SchemaRef,
     file_io: FileIO,
-    predicates: Option<BoundPredicate>,
 }
 
 impl ArrowReader {
@@ -115,16 +91,16 @@ impl ArrowReader {
     pub fn read(self, mut tasks: FileScanTaskStream) -> crate::Result<ArrowRecordBatchStream> {
         let file_io = self.file_io.clone();
 
-        // Collect Parquet column indices from field ids
-        let mut collector = CollectFieldIdVisitor {
-            field_ids: HashSet::default(),
-        };
-        if let Some(predicates) = &self.predicates {
-            visit(&mut collector, predicates)?;
-        }
-
         Ok(try_stream! {
             while let Some(Ok(task)) = tasks.next().await {
+                // Collect Parquet column indices from field ids
+                let mut collector = CollectFieldIdVisitor {
+                    field_ids: HashSet::default(),
+                };
+                if let Some(predicates) = task.predicate() {
+                    visit(&mut collector, predicates)?;
+                }
+
                 let parquet_file = file_io
                     .new_input(task.data_file_path())?;
                 let (parquet_metadata, parquet_reader) = try_join!(parquet_file.metadata(), parquet_file.reader())?;
@@ -135,11 +111,11 @@ impl ArrowReader {
 
                 let parquet_schema = batch_stream_builder.parquet_schema();
                 let arrow_schema = batch_stream_builder.schema();
-                let projection_mask = self.get_arrow_projection_mask(parquet_schema, arrow_schema)?;
+                let projection_mask = self.get_arrow_projection_mask(task.project_field_ids(),task.schema(),parquet_schema, arrow_schema)?;
                 batch_stream_builder = batch_stream_builder.with_projection(projection_mask);
 
                 let parquet_schema = batch_stream_builder.parquet_schema();
-                let row_filter = self.get_row_filter(parquet_schema, &collector)?;
+                let row_filter = self.get_row_filter(task.predicate(),parquet_schema, &collector)?;
 
                 if let Some(row_filter) = row_filter {
                     batch_stream_builder = batch_stream_builder.with_row_filter(row_filter);
@@ -161,10 +137,12 @@ impl ArrowReader {
 
     fn get_arrow_projection_mask(
         &self,
+        field_ids: &[i32],
+        iceberg_schema_of_task: &Schema,
         parquet_schema: &SchemaDescriptor,
         arrow_schema: &ArrowSchemaRef,
     ) -> crate::Result<ProjectionMask> {
-        if self.field_ids.is_empty() {
+        if field_ids.is_empty() {
             Ok(ProjectionMask::all())
         } else {
             // Build the map between field id and column index in Parquet schema.
@@ -184,11 +162,11 @@ impl ArrowReader {
                 }
                 let field_id = field_id.unwrap();
 
-                if !self.field_ids.contains(&(field_id as usize)) {
+                if !field_ids.contains(&field_id) {
                     return false;
                 }
 
-                let iceberg_field = self.schema.field_by_id(field_id);
+                let iceberg_field = iceberg_schema_of_task.field_by_id(field_id);
                 let parquet_iceberg_field = iceberg_schema.field_by_id(field_id);
 
                 if iceberg_field.is_none() || parquet_iceberg_field.is_none() {
@@ -203,19 +181,19 @@ impl ArrowReader {
                 true
             });
 
-            if column_map.len() != self.field_ids.len() {
+            if column_map.len() != field_ids.len() {
                 return Err(Error::new(
                     ErrorKind::DataInvalid,
                     format!(
                         "Parquet schema {} and Iceberg schema {} do not match.",
-                        iceberg_schema, self.schema
+                        iceberg_schema, iceberg_schema_of_task
                     ),
                 ));
             }
 
             let mut indices = vec![];
-            for field_id in &self.field_ids {
-                if let Some(col_idx) = column_map.get(&(*field_id as i32)) {
+            for field_id in field_ids {
+                if let Some(col_idx) = column_map.get(field_id) {
                     indices.push(*col_idx);
                 } else {
                     return Err(Error::new(
@@ -230,10 +208,11 @@ impl ArrowReader {
 
     fn get_row_filter(
         &self,
+        predicates: Option<&BoundPredicate>,
         parquet_schema: &SchemaDescriptor,
         collector: &CollectFieldIdVisitor,
     ) -> Result<Option<RowFilter>> {
-        if let Some(predicates) = &self.predicates {
+        if let Some(predicates) = predicates {
             let field_id_map = build_field_id_map(parquet_schema)?;
 
             // Collect Parquet column indices from field ids.
diff --git a/crates/iceberg/src/scan.rs b/crates/iceberg/src/scan.rs
index 286cf57..8d3ef10 100644
--- a/crates/iceberg/src/scan.rs
+++ b/crates/iceberg/src/scan.rs
@@ -159,11 +159,50 @@ impl<'a> TableScanBuilder<'a> {
             None
         };
 
+        let mut field_ids = vec![];
+        for column_name in &self.column_names {
+            let field_id = schema.field_id_by_name(column_name).ok_or_else(|| {
+                Error::new(
+                    ErrorKind::DataInvalid,
+                    format!(
+                        "Column {} not found in table. Schema: {}",
+                        column_name, schema
+                    ),
+                )
+            })?;
+
+            let field = schema
+                .as_struct()
+                .field_by_id(field_id)
+                .ok_or_else(|| {
+                    Error::new(
+                        ErrorKind::FeatureUnsupported,
+                        format!(
+                            "Column {} is not a direct child of schema but a nested field, which is not supported now. Schema: {}",
+                            column_name, schema
+                        ),
+                    )
+                })?;
+
+            if !field.field_type.is_primitive() {
+                return Err(Error::new(
+                    ErrorKind::FeatureUnsupported,
+                    format!(
+                        "Column {} is not a primitive type. Schema: {}",
+                        column_name, schema
+                    ),
+                ));
+            }
+
+            field_ids.push(field_id);
+        }
+
         Ok(TableScan {
             snapshot,
             file_io: self.table.file_io().clone(),
             table_metadata: self.table.metadata_ref(),
             column_names: self.column_names,
+            field_ids,
             bound_predicates,
             schema,
             batch_size: self.batch_size,
@@ -181,6 +220,7 @@ pub struct TableScan {
     table_metadata: TableMetadataRef,
     file_io: FileIO,
     column_names: Vec<String>,
+    field_ids: Vec<i32>,
     bound_predicates: Option<BoundPredicate>,
     schema: SchemaRef,
     batch_size: Option<usize>,
@@ -204,6 +244,9 @@ impl TableScan {
         let mut manifest_evaluator_cache = ManifestEvaluatorCache::new();
         let mut expression_evaluator_cache = ExpressionEvaluatorCache::new();
 
+        let field_ids = self.field_ids.clone();
+        let bound_predicates = self.bound_predicates.clone();
+
         Ok(try_stream! {
             let manifest_list = context
                 .snapshot
@@ -272,6 +315,9 @@ impl TableScan {
                                 data_file_path: manifest_entry.data_file().file_path().to_string(),
                                 start: 0,
                                 length: manifest_entry.file_size_in_bytes(),
+                                project_field_ids: field_ids.clone(),
+                                predicate: bound_predicates.clone(),
+                                schema: context.schema.clone(),
                             });
                             yield scan_task?;
                         }
@@ -284,57 +330,12 @@ impl TableScan {
 
     /// Returns an [`ArrowRecordBatchStream`].
     pub async fn to_arrow(&self) -> Result<ArrowRecordBatchStream> {
-        let mut arrow_reader_builder =
-            ArrowReaderBuilder::new(self.file_io.clone(), self.schema.clone());
-
-        let mut field_ids = vec![];
-        for column_name in &self.column_names {
-            let field_id = self.schema.field_id_by_name(column_name).ok_or_else(|| {
-                Error::new(
-                    ErrorKind::DataInvalid,
-                    format!(
-                        "Column {} not found in table. Schema: {}",
-                        column_name, self.schema
-                    ),
-                )
-            })?;
-
-            let field = self.schema
-                .as_struct()
-                .field_by_id(field_id)
-                .ok_or_else(|| {
-                    Error::new(
-                        ErrorKind::FeatureUnsupported,
-                        format!(
-                            "Column {} is not a direct child of schema but a nested field, which is not supported now. Schema: {}",
-                            column_name, self.schema
-                        ),
-                    )
-                })?;
-
-            if !field.field_type.is_primitive() {
-                return Err(Error::new(
-                    ErrorKind::FeatureUnsupported,
-                    format!(
-                        "Column {} is not a primitive type. Schema: {}",
-                        column_name, self.schema
-                    ),
-                ));
-            }
-
-            field_ids.push(field_id as usize);
-        }
-
-        arrow_reader_builder = arrow_reader_builder.with_field_ids(field_ids);
+        let mut arrow_reader_builder = ArrowReaderBuilder::new(self.file_io.clone());
 
         if let Some(batch_size) = self.batch_size {
             arrow_reader_builder = arrow_reader_builder.with_batch_size(batch_size);
         }
 
-        if let Some(ref bound_predicates) = self.bound_predicates {
-            arrow_reader_builder = arrow_reader_builder.with_predicates(bound_predicates.clone());
-        }
-
         arrow_reader_builder.build().read(self.plan_files().await?)
     }
 
@@ -491,10 +492,12 @@ impl ExpressionEvaluatorCache {
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct FileScanTask {
     data_file_path: String,
-    #[allow(dead_code)]
     start: u64,
-    #[allow(dead_code)]
     length: u64,
+    project_field_ids: Vec<i32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    predicate: Option<BoundPredicate>,
+    schema: SchemaRef,
 }
 
 impl FileScanTask {
@@ -502,21 +505,39 @@ impl FileScanTask {
     pub fn data_file_path(&self) -> &str {
         &self.data_file_path
     }
+
+    /// Returns the projected field ids of this file scan task.
+    pub fn project_field_ids(&self) -> &[i32] {
+        &self.project_field_ids
+    }
+
+    /// Returns the predicate of this file scan task.
+    pub fn predicate(&self) -> Option<&BoundPredicate> {
+        self.predicate.as_ref()
+    }
+
+    /// Returns the schema of this file scan task.
+    pub fn schema(&self) -> &Schema {
+        &self.schema
+    }
 }
 
 #[cfg(test)]
 mod tests {
-    use crate::expr::Reference;
+    use crate::arrow::ArrowReaderBuilder;
+    use crate::expr::{BoundPredicate, Reference};
     use crate::io::{FileIO, OutputFile};
+    use crate::scan::FileScanTask;
     use crate::spec::{
         DataContentType, DataFileBuilder, DataFileFormat, Datum, FormatVersion, Literal, Manifest,
         ManifestContentType, ManifestEntry, ManifestListWriter, ManifestMetadata, ManifestStatus,
-        ManifestWriter, Struct, TableMetadata, EMPTY_SNAPSHOT_ID,
+        ManifestWriter, NestedField, PrimitiveType, Schema, Struct, TableMetadata, Type,
+        EMPTY_SNAPSHOT_ID,
     };
     use crate::table::Table;
     use crate::TableIdent;
     use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray};
-    use futures::TryStreamExt;
+    use futures::{stream, TryStreamExt};
     use parquet::arrow::{ArrowWriter, PARQUET_FIELD_ID_META_KEY};
     use parquet::basic::Compression;
     use parquet::file::properties::WriterProperties;
@@ -865,6 +886,39 @@ mod tests {
         assert_eq!(int64_arr.value(0), 1);
     }
 
+    #[tokio::test]
+    async fn test_open_parquet_no_deletions_by_separate_reader() {
+        let mut fixture = TableTestFixture::new();
+        fixture.setup_manifest_files().await;
+
+        // Create table scan for current snapshot and plan files
+        let table_scan = fixture.table.scan().build().unwrap();
+
+        let mut plan_task: Vec<_> = table_scan
+            .plan_files()
+            .await
+            .unwrap()
+            .try_collect()
+            .await
+            .unwrap();
+        assert_eq!(plan_task.len(), 2);
+
+        let reader = ArrowReaderBuilder::new(fixture.table.file_io().clone()).build();
+        let batch_stream = reader
+            .clone()
+            .read(Box::pin(stream::iter(vec![Ok(plan_task.remove(0))])))
+            .unwrap();
+        let batche1: Vec<_> = batch_stream.try_collect().await.unwrap();
+
+        let reader = ArrowReaderBuilder::new(fixture.table.file_io().clone()).build();
+        let batch_stream = reader
+            .read(Box::pin(stream::iter(vec![Ok(plan_task.remove(0))])))
+            .unwrap();
+        let batche2: Vec<_> = batch_stream.try_collect().await.unwrap();
+
+        assert_eq!(batche1, batche2);
+    }
+
     #[tokio::test]
     async fn test_open_parquet_with_projection() {
         let mut fixture = TableTestFixture::new();
@@ -1134,4 +1188,51 @@ mod tests {
         let string_arr = col.as_any().downcast_ref::<StringArray>().unwrap();
         assert_eq!(string_arr.value(0), "Apache");
     }
+
+    #[test]
+    fn test_file_scan_task_serialize_deserialize() {
+        let test_fn = |task: FileScanTask| {
+            let serialized = serde_json::to_string(&task).unwrap();
+            let deserialized: FileScanTask = serde_json::from_str(&serialized).unwrap();
+
+            assert_eq!(task.data_file_path, deserialized.data_file_path);
+            assert_eq!(task.start, deserialized.start);
+            assert_eq!(task.length, deserialized.length);
+            assert_eq!(task.project_field_ids, deserialized.project_field_ids);
+            assert_eq!(task.predicate, deserialized.predicate);
+            assert_eq!(task.schema, deserialized.schema);
+        };
+
+        // without predicate
+        let schema = Arc::new(
+            Schema::builder()
+                .with_fields(vec![Arc::new(NestedField::required(
+                    1,
+                    "x",
+                    Type::Primitive(PrimitiveType::Binary),
+                ))])
+                .build()
+                .unwrap(),
+        );
+        let task = FileScanTask {
+            data_file_path: "data_file_path".to_string(),
+            start: 0,
+            length: 100,
+            project_field_ids: vec![1, 2, 3],
+            predicate: None,
+            schema: schema.clone(),
+        };
+        test_fn(task);
+
+        // with predicate
+        let task = FileScanTask {
+            data_file_path: "data_file_path".to_string(),
+            start: 0,
+            length: 100,
+            project_field_ids: vec![1, 2, 3],
+            predicate: Some(BoundPredicate::AlwaysTrue),
+            schema,
+        };
+        test_fn(task);
+    }
 }
diff --git a/crates/iceberg/src/spec/schema.rs b/crates/iceberg/src/spec/schema.rs
index 2f21f20..93edcd7 100644
--- a/crates/iceberg/src/spec/schema.rs
+++ b/crates/iceberg/src/spec/schema.rs
@@ -317,7 +317,7 @@ impl Schema {
 
     /// Returns [`schema_id`].
     #[inline]
-    pub fn schema_id(&self) -> i32 {
+    pub fn schema_id(&self) -> SchemaId {
         self.schema_id
     }
 
diff --git a/crates/iceberg/src/table.rs b/crates/iceberg/src/table.rs
index fd8bd28..c76b286 100644
--- a/crates/iceberg/src/table.rs
+++ b/crates/iceberg/src/table.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 //! Table API for Apache Iceberg
+use crate::arrow::ArrowReaderBuilder;
 use crate::io::FileIO;
 use crate::scan::TableScanBuilder;
 use crate::spec::{TableMetadata, TableMetadataRef};
@@ -70,6 +71,11 @@ impl Table {
     pub fn readonly(&self) -> bool {
         self.readonly
     }
+
+    /// Create a reader for the table.
+    pub fn reader_builder(&self) -> ArrowReaderBuilder {
+        ArrowReaderBuilder::new(self.file_io.clone())
+    }
 }
 
 /// `StaticTable` is a read-only table struct that can be created from a metadata file or from `TableMetaData` without a catalog.
@@ -138,6 +144,11 @@ impl StaticTable {
     pub fn into_table(self) -> Table {
         self.0
     }
+
+    /// Create a reader for the table.
+    pub fn reader_builder(&self) -> ArrowReaderBuilder {
+        ArrowReaderBuilder::new(self.0.file_io.clone())
+    }
 }
 
 #[cfg(test)]

Reply via email to