This is an automated email from the ASF dual-hosted git repository.

liurenjie1024 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-rust.git


The following commit(s) were added to refs/heads/main by this push:
     new 11e36c0a feat: allow empty projection in table scan (#677)
11e36c0a is described below

commit 11e36c0ae635aac57471f82f9e7e0c12e587aa22
Author: sundyli <[email protected]>
AuthorDate: Fri Oct 25 17:40:20 2024 +0800

    feat: allow empty projection in table scan (#677)
    
    * fix: allow empty projection in scan
    
    * fix: allow empty projection in scan
    
    * fix: pub get_manifest_list
    
    * Update crates/iceberg/src/scan.rs
    
    Co-authored-by: Renjie Liu <[email protected]>
    
    * chore: remove pub
    
    ---------
    
    Co-authored-by: Renjie Liu <[email protected]>
---
 .../iceberg/src/arrow/record_batch_transformer.rs  | 29 ++++------
 crates/iceberg/src/scan.rs                         | 63 ++++++++++++++++------
 2 files changed, 57 insertions(+), 35 deletions(-)

diff --git a/crates/iceberg/src/arrow/record_batch_transformer.rs b/crates/iceberg/src/arrow/record_batch_transformer.rs
index 01ce9f0a..216e68ea 100644
--- a/crates/iceberg/src/arrow/record_batch_transformer.rs
+++ b/crates/iceberg/src/arrow/record_batch_transformer.rs
@@ -20,7 +20,7 @@ use std::sync::Arc;
 
 use arrow_array::{
     Array as ArrowArray, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array,
-    Int32Array, Int64Array, NullArray, RecordBatch, StringArray,
+    Int32Array, Int64Array, NullArray, RecordBatch, RecordBatchOptions, StringArray,
 };
 use arrow_cast::cast;
 use arrow_schema::{
@@ -124,19 +124,7 @@ impl RecordBatchTransformer {
         snapshot_schema: Arc<IcebergSchema>,
         projected_iceberg_field_ids: &[i32],
     ) -> Self {
-        let projected_iceberg_field_ids = if projected_iceberg_field_ids.is_empty() {
-            // If the list of field ids is empty, this indicates that we
-            // need to select all fields.
-            // Project all fields in table schema order
-            snapshot_schema
-                .as_struct()
-                .fields()
-                .iter()
-                .map(|field| field.id)
-                .collect()
-        } else {
-            projected_iceberg_field_ids.to_vec()
-        };
+        let projected_iceberg_field_ids = projected_iceberg_field_ids.to_vec();
 
         Self {
             snapshot_schema,
@@ -154,10 +142,15 @@ impl RecordBatchTransformer {
             Some(BatchTransform::Modify {
                 ref target_schema,
                 ref operations,
-            }) => RecordBatch::try_new(
-                target_schema.clone(),
-                self.transform_columns(record_batch.columns(), operations)?,
-            )?,
+            }) => {
+                let options =
+                    RecordBatchOptions::default().with_row_count(Some(record_batch.num_rows()));
+                RecordBatch::try_new_with_options(
+                    target_schema.clone(),
+                    self.transform_columns(record_batch.columns(), operations)?,
+                    &options,
+                )?
+            }
             Some(BatchTransform::ModifySchema { target_schema }) => {
                 record_batch.with_schema(target_schema.clone())?
             }
diff --git a/crates/iceberg/src/scan.rs b/crates/iceberg/src/scan.rs
index b7fa62d6..d0355454 100644
--- a/crates/iceberg/src/scan.rs
+++ b/crates/iceberg/src/scan.rs
@@ -51,8 +51,8 @@ pub type ArrowRecordBatchStream = BoxStream<'static, Result<RecordBatch>>;
 /// Builder to create table scan.
 pub struct TableScanBuilder<'a> {
     table: &'a Table,
-    // Empty column names means to select all columns
-    column_names: Vec<String>,
+    // Defaults to none which means select all columns
+    column_names: Option<Vec<String>>,
     snapshot_id: Option<i64>,
     batch_size: Option<usize>,
     case_sensitive: bool,
@@ -70,7 +70,7 @@ impl<'a> TableScanBuilder<'a> {
 
         Self {
             table,
-            column_names: vec![],
+            column_names: None,
             snapshot_id: None,
             batch_size: None,
             case_sensitive: true,
@@ -106,16 +106,24 @@ impl<'a> TableScanBuilder<'a> {
 
     /// Select all columns.
     pub fn select_all(mut self) -> Self {
-        self.column_names.clear();
+        self.column_names = None;
+        self
+    }
+
+    /// Select empty columns.
+    pub fn select_empty(mut self) -> Self {
+        self.column_names = Some(vec![]);
         self
     }
 
     /// Select some columns of the table.
     pub fn select(mut self, column_names: impl IntoIterator<Item = impl ToString>) -> Self {
-        self.column_names = column_names
-            .into_iter()
-            .map(|item| item.to_string())
-            .collect();
+        self.column_names = Some(
+            column_names
+                .into_iter()
+                .map(|item| item.to_string())
+                .collect(),
+        );
         self
     }
 
@@ -205,8 +213,8 @@ impl<'a> TableScanBuilder<'a> {
         let schema = snapshot.schema(self.table.metadata())?;
 
         // Check that all column names exist in the schema.
-        if !self.column_names.is_empty() {
-            for column_name in &self.column_names {
+        if let Some(column_names) = self.column_names.as_ref() {
+            for column_name in column_names {
                 if schema.field_by_name(column_name).is_none() {
                     return Err(Error::new(
                         ErrorKind::DataInvalid,
@@ -220,7 +228,16 @@ impl<'a> TableScanBuilder<'a> {
         }
 
         let mut field_ids = vec![];
-        for column_name in &self.column_names {
+        let column_names = self.column_names.clone().unwrap_or_else(|| {
+            schema
+                .as_struct()
+                .fields()
+                .iter()
+                .map(|f| f.name.clone())
+                .collect()
+        });
+
+        for column_name in column_names.iter() {
             let field_id = schema.field_id_by_name(column_name).ok_or_else(|| {
                 Error::new(
                     ErrorKind::DataInvalid,
@@ -297,7 +314,7 @@ pub struct TableScan {
     plan_context: PlanContext,
     batch_size: Option<usize>,
     file_io: FileIO,
-    column_names: Vec<String>,
+    column_names: Option<Vec<String>>,
     /// The maximum number of manifest files that will be
     /// retrieved from [`FileIO`] concurrently
     concurrency_limit_manifest_files: usize,
@@ -409,9 +426,10 @@ impl TableScan {
     }
 
     /// Returns a reference to the column names of the table scan.
-    pub fn column_names(&self) -> &[String] {
-        &self.column_names
+    pub fn column_names(&self) -> Option<&[String]> {
+        self.column_names.as_deref()
     }
+
     /// Returns a reference to the snapshot of the table scan.
     pub fn snapshot(&self) -> &SnapshotRef {
         &self.plan_context.snapshot
@@ -1236,7 +1254,10 @@ mod tests {
         let table = TableTestFixture::new().table;
 
         let table_scan = table.scan().select(["x", "y"]).build().unwrap();
-        assert_eq!(vec!["x", "y"], table_scan.column_names);
+        assert_eq!(
+            Some(vec!["x".to_string(), "y".to_string()]),
+            table_scan.column_names
+        );
 
         let table_scan = table
             .scan()
@@ -1244,7 +1265,7 @@ mod tests {
             .select(["z"])
             .build()
             .unwrap();
-        assert_eq!(vec!["z"], table_scan.column_names);
+        assert_eq!(Some(vec!["z".to_string()]), table_scan.column_names);
     }
 
     #[test]
@@ -1252,7 +1273,7 @@ mod tests {
         let table = TableTestFixture::new().table;
 
         let table_scan = table.scan().select_all().build().unwrap();
-        assert!(table_scan.column_names.is_empty());
+        assert!(table_scan.column_names.is_none());
     }
 
     #[test]
@@ -1424,6 +1445,14 @@ mod tests {
         let col2 = batches[0].column_by_name("z").unwrap();
         let int64_arr = col2.as_any().downcast_ref::<Int64Array>().unwrap();
         assert_eq!(int64_arr.value(0), 3);
+
+        // test empty scan
+        let table_scan = fixture.table.scan().select_empty().build().unwrap();
+        let batch_stream = table_scan.to_arrow().await.unwrap();
+        let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+
+        assert_eq!(batches[0].num_columns(), 0);
+        assert_eq!(batches[0].num_rows(), 1024);
     }
 
     #[tokio::test]

Reply via email to