This is an automated email from the ASF dual-hosted git repository.

liurenjie1024 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-rust.git


The following commit(s) were added to refs/heads/main by this push:
     new 81df940  feat: Convert predicate to arrow filter and push down to parquet reader (#295)
81df940 is described below

commit 81df9406a4dd96de8ad82f5f9253738b44a3fe31
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue May 14 18:46:44 2024 -0700

    feat: Convert predicate to arrow filter and push down to parquet reader (#295)
    
    * feat: Convert predicate to arrow filter and push down to parquet reader
    
    * For review
    
    * Fix clippy
    
    * Change from vector of BoundPredicate to BoundPredicate
    
    * Add test for CollectFieldIdVisitor
    
    * Return projection_mask for leaf column
    
    * Update
    
    * For review
    
    * For review
    
    * For review
    
    * For review
    
    * More
    
    * fix
    
    * Fix clippy
    
    * More
    
    * Fix clippy
    
    * fix clippy
---
 Cargo.toml                           |   1 +
 crates/iceberg/Cargo.toml            |   1 +
 crates/iceberg/src/arrow/reader.rs   | 678 ++++++++++++++++++++++++++++++++++-
 crates/iceberg/src/arrow/schema.rs   |  26 +-
 crates/iceberg/src/expr/predicate.rs |   9 +
 crates/iceberg/src/scan.rs           | 198 +++++++++-
 6 files changed, 903 insertions(+), 10 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 57c3436..125cd0d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -40,6 +40,7 @@ apache-avro = "0.16"
 array-init = "2"
 arrow-arith = { version = "51" }
 arrow-array = { version = "51" }
+arrow-ord = { version = "51" }
 arrow-schema = { version = "51" }
 arrow-select = { version = "51" }
 async-stream = "0.3.5"
diff --git a/crates/iceberg/Cargo.toml b/crates/iceberg/Cargo.toml
index 46f167b..95e0078 100644
--- a/crates/iceberg/Cargo.toml
+++ b/crates/iceberg/Cargo.toml
@@ -34,6 +34,7 @@ apache-avro = { workspace = true }
 array-init = { workspace = true }
 arrow-arith = { workspace = true }
 arrow-array = { workspace = true }
+arrow-ord = { workspace = true }
 arrow-schema = { workspace = true }
 arrow-select = { workspace = true }
 async-stream = { workspace = true }
diff --git a/crates/iceberg/src/arrow/reader.rs b/crates/iceberg/src/arrow/reader.rs
index fe5efac..391239c 100644
--- a/crates/iceberg/src/arrow/reader.rs
+++ b/crates/iceberg/src/arrow/reader.rs
@@ -17,25 +17,33 @@
 
 //! Parquet file data reader
 
-use arrow_schema::SchemaRef as ArrowSchemaRef;
+use crate::error::Result;
+use arrow_arith::boolean::{and, is_not_null, is_null, not, or};
+use arrow_array::{ArrayRef, BooleanArray, RecordBatch};
+use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq};
+use arrow_schema::{ArrowError, DataType, SchemaRef as ArrowSchemaRef};
 use async_stream::try_stream;
 use bytes::Bytes;
+use fnv::FnvHashSet;
 use futures::future::BoxFuture;
 use futures::stream::StreamExt;
 use futures::{try_join, TryFutureExt};
+use parquet::arrow::arrow_reader::{ArrowPredicateFn, RowFilter};
 use parquet::arrow::async_reader::{AsyncFileReader, MetadataLoader};
 use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask, 
PARQUET_FIELD_ID_META_KEY};
 use parquet::file::metadata::ParquetMetaData;
-use parquet::schema::types::SchemaDescriptor;
-use std::collections::HashMap;
+use parquet::schema::types::{SchemaDescriptor, Type as ParquetType};
+use std::collections::{HashMap, HashSet};
 use std::ops::Range;
 use std::str::FromStr;
 use std::sync::Arc;
 
-use crate::arrow::arrow_schema_to_schema;
+use crate::arrow::{arrow_schema_to_schema, get_arrow_datum};
+use crate::expr::visitors::bound_predicate_visitor::{visit, 
BoundPredicateVisitor};
+use crate::expr::{BoundPredicate, BoundReference};
 use crate::io::{FileIO, FileMetadata, FileRead};
 use crate::scan::{ArrowRecordBatchStream, FileScanTaskStream};
-use crate::spec::SchemaRef;
+use crate::spec::{Datum, SchemaRef};
 use crate::{Error, ErrorKind};
 
 /// Builder to create ArrowReader
@@ -44,6 +52,7 @@ pub struct ArrowReaderBuilder {
     field_ids: Vec<usize>,
     file_io: FileIO,
     schema: SchemaRef,
+    predicates: Option<BoundPredicate>,
 }
 
 impl ArrowReaderBuilder {
@@ -54,6 +63,7 @@ impl ArrowReaderBuilder {
             field_ids: vec![],
             file_io,
             schema,
+            predicates: None,
         }
     }
 
@@ -70,6 +80,12 @@ impl ArrowReaderBuilder {
         self
     }
 
+    /// Sets the predicates to apply to the scan.
+    pub fn with_predicates(mut self, predicates: BoundPredicate) -> Self {
+        self.predicates = Some(predicates);
+        self
+    }
+
     /// Build the ArrowReader.
     pub fn build(self) -> ArrowReader {
         ArrowReader {
@@ -77,6 +93,7 @@ impl ArrowReaderBuilder {
             field_ids: self.field_ids,
             schema: self.schema,
             file_io: self.file_io,
+            predicates: self.predicates,
         }
     }
 }
@@ -88,6 +105,7 @@ pub struct ArrowReader {
     #[allow(dead_code)]
     schema: SchemaRef,
     file_io: FileIO,
+    predicates: Option<BoundPredicate>,
 }
 
 impl ArrowReader {
@@ -96,6 +114,14 @@ impl ArrowReader {
     pub fn read(self, mut tasks: FileScanTaskStream) -> 
crate::Result<ArrowRecordBatchStream> {
         let file_io = self.file_io.clone();
 
+        // Collect Parquet column indices from field ids
+        let mut collector = CollectFieldIdVisitor {
+            field_ids: HashSet::default(),
+        };
+        if let Some(predicates) = &self.predicates {
+            visit(&mut collector, predicates)?;
+        }
+
         Ok(try_stream! {
             while let Some(Ok(task)) = tasks.next().await {
                 let parquet_file = file_io
@@ -111,6 +137,13 @@ impl ArrowReader {
                 let projection_mask = 
self.get_arrow_projection_mask(parquet_schema, arrow_schema)?;
                 batch_stream_builder = 
batch_stream_builder.with_projection(projection_mask);
 
+                let parquet_schema = batch_stream_builder.parquet_schema();
+                let row_filter = self.get_row_filter(parquet_schema, 
&collector)?;
+
+                if let Some(row_filter) = row_filter {
+                    batch_stream_builder = 
batch_stream_builder.with_row_filter(row_filter);
+                }
+
                 if let Some(batch_size) = self.batch_size {
                     batch_stream_builder = 
batch_stream_builder.with_batch_size(batch_size);
                 }
@@ -193,6 +226,558 @@ impl ArrowReader {
             Ok(ProjectionMask::leaves(parquet_schema, indices))
         }
     }
+
+    fn get_row_filter(
+        &self,
+        parquet_schema: &SchemaDescriptor,
+        collector: &CollectFieldIdVisitor,
+    ) -> Result<Option<RowFilter>> {
+        if let Some(predicates) = &self.predicates {
+            let field_id_map = build_field_id_map(parquet_schema)?;
+
+            // Collect Parquet column indices from field ids.
+            // If the field id is not found in Parquet schema, it will be 
ignored due to schema evolution.
+            let mut column_indices = collector
+                .field_ids
+                .iter()
+                .filter_map(|field_id| field_id_map.get(field_id).cloned())
+                .collect::<Vec<_>>();
+
+            column_indices.sort();
+
+            // The converter that converts `BoundPredicates` to 
`ArrowPredicates`
+            let mut converter = PredicateConverter {
+                parquet_schema,
+                column_map: &field_id_map,
+                column_indices: &column_indices,
+            };
+
+            // After collecting required leaf column indices used in the 
predicate,
+            // creates the projection mask for the Arrow predicates.
+            let projection_mask = ProjectionMask::leaves(parquet_schema, 
column_indices.clone());
+            let predicate_func = visit(&mut converter, predicates)?;
+            let arrow_predicate = ArrowPredicateFn::new(projection_mask, 
predicate_func);
+            Ok(Some(RowFilter::new(vec![Box::new(arrow_predicate)])))
+        } else {
+            Ok(None)
+        }
+    }
+}
+
+/// Build the map of field id to Parquet column index in the schema.
+fn build_field_id_map(parquet_schema: &SchemaDescriptor) -> 
Result<HashMap<i32, usize>> {
+    let mut column_map = HashMap::new();
+    for (idx, field) in parquet_schema.columns().iter().enumerate() {
+        let field_type = field.self_type();
+        match field_type {
+            ParquetType::PrimitiveType { basic_info, .. } => {
+                if !basic_info.has_id() {
+                    return Err(Error::new(
+                        ErrorKind::DataInvalid,
+                        format!(
+                            "Leave column idx: {}, name: {}, type {:?} in 
schema doesn't have field id",
+                            idx,
+                            basic_info.name(),
+                            field_type
+                        ),
+                    ));
+                }
+                column_map.insert(basic_info.id(), idx);
+            }
+            ParquetType::GroupType { .. } => {
+                return Err(Error::new(
+                    ErrorKind::DataInvalid,
+                    format!(
+                        "Leave column in schema should be primitive type but 
got {:?}",
+                        field_type
+                    ),
+                ));
+            }
+        };
+    }
+
+    Ok(column_map)
+}
+
+/// A visitor to collect field ids from bound predicates.
+struct CollectFieldIdVisitor {
+    field_ids: HashSet<i32>,
+}
+
+impl BoundPredicateVisitor for CollectFieldIdVisitor {
+    type T = ();
+
+    fn always_true(&mut self) -> Result<()> {
+        Ok(())
+    }
+
+    fn always_false(&mut self) -> Result<()> {
+        Ok(())
+    }
+
+    fn and(&mut self, _lhs: (), _rhs: ()) -> Result<()> {
+        Ok(())
+    }
+
+    fn or(&mut self, _lhs: (), _rhs: ()) -> Result<()> {
+        Ok(())
+    }
+
+    fn not(&mut self, _inner: ()) -> Result<()> {
+        Ok(())
+    }
+
+    fn is_null(&mut self, reference: &BoundReference, _predicate: 
&BoundPredicate) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn not_null(&mut self, reference: &BoundReference, _predicate: 
&BoundPredicate) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn is_nan(&mut self, reference: &BoundReference, _predicate: 
&BoundPredicate) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn not_nan(&mut self, reference: &BoundReference, _predicate: 
&BoundPredicate) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn less_than(
+        &mut self,
+        reference: &BoundReference,
+        _literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn less_than_or_eq(
+        &mut self,
+        reference: &BoundReference,
+        _literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn greater_than(
+        &mut self,
+        reference: &BoundReference,
+        _literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn greater_than_or_eq(
+        &mut self,
+        reference: &BoundReference,
+        _literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn eq(
+        &mut self,
+        reference: &BoundReference,
+        _literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn not_eq(
+        &mut self,
+        reference: &BoundReference,
+        _literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn starts_with(
+        &mut self,
+        reference: &BoundReference,
+        _literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn not_starts_with(
+        &mut self,
+        reference: &BoundReference,
+        _literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn r#in(
+        &mut self,
+        reference: &BoundReference,
+        _literals: &FnvHashSet<Datum>,
+        _predicate: &BoundPredicate,
+    ) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn not_in(
+        &mut self,
+        reference: &BoundReference,
+        _literals: &FnvHashSet<Datum>,
+        _predicate: &BoundPredicate,
+    ) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+}
+
+/// A visitor to convert Iceberg bound predicates to Arrow predicates.
+struct PredicateConverter<'a> {
+    /// The Parquet schema descriptor.
+    pub parquet_schema: &'a SchemaDescriptor,
+    /// The map between field id and leaf column index in Parquet schema.
+    pub column_map: &'a HashMap<i32, usize>,
+    /// The required column indices in Parquet schema for the predicates.
+    pub column_indices: &'a Vec<usize>,
+}
+
+impl PredicateConverter<'_> {
+    /// When visiting a bound reference, we return index of the leaf column in 
the
+    /// required column indices which is used to project the column in the 
record batch.
+    /// Return None if the field id is not found in the column map, which is 
possible
+    /// due to schema evolution.
+    fn bound_reference(&mut self, reference: &BoundReference) -> 
Result<Option<usize>> {
+        // The leaf column's index in Parquet schema.
+        if let Some(column_idx) = self.column_map.get(&reference.field().id) {
+            if self.parquet_schema.get_column_root_idx(*column_idx) != 
*column_idx {
+                return Err(Error::new(
+                    ErrorKind::DataInvalid,
+                    format!(
+                        "Leave column `{}` in predicates isn't a root column 
in Parquet schema.",
+                        reference.field().name
+                    ),
+                ));
+            }
+
+            // The leaf column's index in the required column indices.
+            let index = self
+                .column_indices
+                .iter()
+                .position(|&idx| idx == 
*column_idx).ok_or(Error::new(ErrorKind::DataInvalid, format!(
+                    "Leave column `{}` in predicates cannot be found in the 
required column indices.",
+                    reference.field().name
+                )))?;
+
+            Ok(Some(index))
+        } else {
+            Ok(None)
+        }
+    }
+
+    /// Build an Arrow predicate that always returns true.
+    fn build_always_true(&self) -> Result<Box<PredicateResult>> {
+        Ok(Box::new(|batch| {
+            Ok(BooleanArray::from(vec![true; batch.num_rows()]))
+        }))
+    }
+
+    /// Build an Arrow predicate that always returns false.
+    fn build_always_false(&self) -> Result<Box<PredicateResult>> {
+        Ok(Box::new(|batch| {
+            Ok(BooleanArray::from(vec![false; batch.num_rows()]))
+        }))
+    }
+}
+
+/// Gets the leaf column from the record batch for the required column index. 
Only
+/// supports top-level columns for now.
+fn project_column(
+    batch: &RecordBatch,
+    column_idx: usize,
+) -> std::result::Result<ArrayRef, ArrowError> {
+    let column = batch.column(column_idx);
+
+    match column.data_type() {
+        DataType::Struct(_) => Err(ArrowError::SchemaError(
+            "Does not support struct column yet.".to_string(),
+        )),
+        _ => Ok(column.clone()),
+    }
+}
+
+type PredicateResult =
+    dyn FnMut(RecordBatch) -> std::result::Result<BooleanArray, ArrowError> + 
Send + 'static;
+
+impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
+    type T = Box<PredicateResult>;
+
+    fn always_true(&mut self) -> Result<Box<PredicateResult>> {
+        self.build_always_true()
+    }
+
+    fn always_false(&mut self) -> Result<Box<PredicateResult>> {
+        self.build_always_false()
+    }
+
+    fn and(
+        &mut self,
+        mut lhs: Box<PredicateResult>,
+        mut rhs: Box<PredicateResult>,
+    ) -> Result<Box<PredicateResult>> {
+        Ok(Box::new(move |batch| {
+            let left = lhs(batch.clone())?;
+            let right = rhs(batch)?;
+            and(&left, &right)
+        }))
+    }
+
+    fn or(
+        &mut self,
+        mut lhs: Box<PredicateResult>,
+        mut rhs: Box<PredicateResult>,
+    ) -> Result<Box<PredicateResult>> {
+        Ok(Box::new(move |batch| {
+            let left = lhs(batch.clone())?;
+            let right = rhs(batch)?;
+            or(&left, &right)
+        }))
+    }
+
+    fn not(&mut self, mut inner: Box<PredicateResult>) -> 
Result<Box<PredicateResult>> {
+        Ok(Box::new(move |batch| {
+            let pred_ret = inner(batch)?;
+            not(&pred_ret)
+        }))
+    }
+
+    fn is_null(
+        &mut self,
+        reference: &BoundReference,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        if let Some(idx) = self.bound_reference(reference)? {
+            Ok(Box::new(move |batch| {
+                let column = project_column(&batch, idx)?;
+                is_null(&column)
+            }))
+        } else {
+            // A missing column, treating it as null.
+            self.build_always_true()
+        }
+    }
+
+    fn not_null(
+        &mut self,
+        reference: &BoundReference,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        if let Some(idx) = self.bound_reference(reference)? {
+            Ok(Box::new(move |batch| {
+                let column = project_column(&batch, idx)?;
+                is_not_null(&column)
+            }))
+        } else {
+            // A missing column, treating it as null.
+            self.build_always_false()
+        }
+    }
+
+    fn is_nan(
+        &mut self,
+        reference: &BoundReference,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        if self.bound_reference(reference)?.is_some() {
+            self.build_always_true()
+        } else {
+            // A missing column, treating it as null.
+            self.build_always_false()
+        }
+    }
+
+    fn not_nan(
+        &mut self,
+        reference: &BoundReference,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        if self.bound_reference(reference)?.is_some() {
+            self.build_always_false()
+        } else {
+            // A missing column, treating it as null.
+            self.build_always_true()
+        }
+    }
+
+    fn less_than(
+        &mut self,
+        reference: &BoundReference,
+        literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        if let Some(idx) = self.bound_reference(reference)? {
+            let literal = get_arrow_datum(literal)?;
+
+            Ok(Box::new(move |batch| {
+                let left = project_column(&batch, idx)?;
+                lt(&left, literal.as_ref())
+            }))
+        } else {
+            // A missing column, treating it as null.
+            self.build_always_true()
+        }
+    }
+
+    fn less_than_or_eq(
+        &mut self,
+        reference: &BoundReference,
+        literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        if let Some(idx) = self.bound_reference(reference)? {
+            let literal = get_arrow_datum(literal)?;
+
+            Ok(Box::new(move |batch| {
+                let left = project_column(&batch, idx)?;
+                lt_eq(&left, literal.as_ref())
+            }))
+        } else {
+            // A missing column, treating it as null.
+            self.build_always_true()
+        }
+    }
+
+    fn greater_than(
+        &mut self,
+        reference: &BoundReference,
+        literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        if let Some(idx) = self.bound_reference(reference)? {
+            let literal = get_arrow_datum(literal)?;
+
+            Ok(Box::new(move |batch| {
+                let left = project_column(&batch, idx)?;
+                gt(&left, literal.as_ref())
+            }))
+        } else {
+            // A missing column, treating it as null.
+            self.build_always_false()
+        }
+    }
+
+    fn greater_than_or_eq(
+        &mut self,
+        reference: &BoundReference,
+        literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        if let Some(idx) = self.bound_reference(reference)? {
+            let literal = get_arrow_datum(literal)?;
+
+            Ok(Box::new(move |batch| {
+                let left = project_column(&batch, idx)?;
+                gt_eq(&left, literal.as_ref())
+            }))
+        } else {
+            // A missing column, treating it as null.
+            self.build_always_false()
+        }
+    }
+
+    fn eq(
+        &mut self,
+        reference: &BoundReference,
+        literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        if let Some(idx) = self.bound_reference(reference)? {
+            let literal = get_arrow_datum(literal)?;
+
+            Ok(Box::new(move |batch| {
+                let left = project_column(&batch, idx)?;
+                eq(&left, literal.as_ref())
+            }))
+        } else {
+            // A missing column, treating it as null.
+            self.build_always_false()
+        }
+    }
+
+    fn not_eq(
+        &mut self,
+        reference: &BoundReference,
+        literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        if let Some(idx) = self.bound_reference(reference)? {
+            let literal = get_arrow_datum(literal)?;
+
+            Ok(Box::new(move |batch| {
+                let left = project_column(&batch, idx)?;
+                neq(&left, literal.as_ref())
+            }))
+        } else {
+            // A missing column, treating it as null.
+            self.build_always_false()
+        }
+    }
+
+    fn starts_with(
+        &mut self,
+        _reference: &BoundReference,
+        _literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        // TODO: Implement starts_with
+        self.build_always_true()
+    }
+
+    fn not_starts_with(
+        &mut self,
+        _reference: &BoundReference,
+        _literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        // TODO: Implement not_starts_with
+        self.build_always_true()
+    }
+
+    fn r#in(
+        &mut self,
+        _reference: &BoundReference,
+        _literals: &FnvHashSet<Datum>,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        // TODO: Implement in
+        self.build_always_true()
+    }
+
+    fn not_in(
+        &mut self,
+        _reference: &BoundReference,
+        _literals: &FnvHashSet<Datum>,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        // TODO: Implement not_in
+        self.build_always_true()
+    }
 }
 
 /// ArrowFileReader is a wrapper around a FileRead that impls parquets 
AsyncFileReader.
@@ -234,3 +819,86 @@ impl<R: FileRead> AsyncFileReader for ArrowFileReader<R> {
         })
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::arrow::reader::CollectFieldIdVisitor;
+    use crate::expr::visitors::bound_predicate_visitor::visit;
+    use crate::expr::{Bind, Reference};
+    use crate::spec::{NestedField, PrimitiveType, Schema, SchemaRef, Type};
+    use std::collections::HashSet;
+    use std::sync::Arc;
+
+    fn table_schema_simple() -> SchemaRef {
+        Arc::new(
+            Schema::builder()
+                .with_schema_id(1)
+                .with_identifier_field_ids(vec![2])
+                .with_fields(vec![
+                    NestedField::optional(1, "foo", 
Type::Primitive(PrimitiveType::String)).into(),
+                    NestedField::required(2, "bar", 
Type::Primitive(PrimitiveType::Int)).into(),
+                    NestedField::optional(3, "baz", 
Type::Primitive(PrimitiveType::Boolean)).into(),
+                    NestedField::optional(4, "qux", 
Type::Primitive(PrimitiveType::Float)).into(),
+                ])
+                .build()
+                .unwrap(),
+        )
+    }
+
+    #[test]
+    fn test_collect_field_id() {
+        let schema = table_schema_simple();
+        let expr = Reference::new("qux").is_null();
+        let bound_expr = expr.bind(schema, true).unwrap();
+
+        let mut visitor = CollectFieldIdVisitor {
+            field_ids: HashSet::default(),
+        };
+        visit(&mut visitor, &bound_expr).unwrap();
+
+        let mut expected = HashSet::default();
+        expected.insert(4_i32);
+
+        assert_eq!(visitor.field_ids, expected);
+    }
+
+    #[test]
+    fn test_collect_field_id_with_and() {
+        let schema = table_schema_simple();
+        let expr = Reference::new("qux")
+            .is_null()
+            .and(Reference::new("baz").is_null());
+        let bound_expr = expr.bind(schema, true).unwrap();
+
+        let mut visitor = CollectFieldIdVisitor {
+            field_ids: HashSet::default(),
+        };
+        visit(&mut visitor, &bound_expr).unwrap();
+
+        let mut expected = HashSet::default();
+        expected.insert(4_i32);
+        expected.insert(3);
+
+        assert_eq!(visitor.field_ids, expected);
+    }
+
+    #[test]
+    fn test_collect_field_id_with_or() {
+        let schema = table_schema_simple();
+        let expr = Reference::new("qux")
+            .is_null()
+            .or(Reference::new("baz").is_null());
+        let bound_expr = expr.bind(schema, true).unwrap();
+
+        let mut visitor = CollectFieldIdVisitor {
+            field_ids: HashSet::default(),
+        };
+        visit(&mut visitor, &bound_expr).unwrap();
+
+        let mut expected = HashSet::default();
+        expected.insert(4_i32);
+        expected.insert(3);
+
+        assert_eq!(visitor.field_ids, expected);
+    }
+}
diff --git a/crates/iceberg/src/arrow/schema.rs b/crates/iceberg/src/arrow/schema.rs
index c7e8700..172d4bb 100644
--- a/crates/iceberg/src/arrow/schema.rs
+++ b/crates/iceberg/src/arrow/schema.rs
@@ -19,12 +19,16 @@
 
 use crate::error::Result;
 use crate::spec::{
-    ListType, MapType, NestedField, NestedFieldRef, PrimitiveType, Schema, 
SchemaVisitor,
-    StructType, Type,
+    Datum, ListType, MapType, NestedField, NestedFieldRef, PrimitiveLiteral, 
PrimitiveType, Schema,
+    SchemaVisitor, StructType, Type,
 };
 use crate::{Error, ErrorKind};
 use arrow_array::types::{validate_decimal_precision_and_scale, Decimal128Type};
+use arrow_array::{
+    BooleanArray, Datum as ArrowDatum, Float32Array, Float64Array, Int32Array, 
Int64Array,
+};
 use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema, TimeUnit};
+use bitvec::macros::internal::funty::Fundamental;
 use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
 use rust_decimal::prelude::ToPrimitive;
 use std::collections::HashMap;
@@ -593,6 +597,24 @@ pub fn schema_to_arrow_schema(schema: &crate::spec::Schema) -> crate::Result<Arr
     }
 }
 
+/// Convert Iceberg Datum to Arrow Datum.
+pub(crate) fn get_arrow_datum(datum: &Datum) -> Result<Box<dyn ArrowDatum + 
Send>> {
+    match datum.literal() {
+        PrimitiveLiteral::Boolean(value) => 
Ok(Box::new(BooleanArray::new_scalar(*value))),
+        PrimitiveLiteral::Int(value) => 
Ok(Box::new(Int32Array::new_scalar(*value))),
+        PrimitiveLiteral::Long(value) => 
Ok(Box::new(Int64Array::new_scalar(*value))),
+        PrimitiveLiteral::Float(value) => 
Ok(Box::new(Float32Array::new_scalar(value.as_f32()))),
+        PrimitiveLiteral::Double(value) => 
Ok(Box::new(Float64Array::new_scalar(value.as_f64()))),
+        l => Err(Error::new(
+            ErrorKind::FeatureUnsupported,
+            format!(
+                "Converting datum from type {:?} to arrow not supported yet.",
+                l
+            ),
+        )),
+    }
+}
+
 impl TryFrom<&ArrowSchema> for crate::spec::Schema {
     type Error = Error;
 
diff --git a/crates/iceberg/src/expr/predicate.rs b/crates/iceberg/src/expr/predicate.rs
index 1457d5a..158ab13 100644
--- a/crates/iceberg/src/expr/predicate.rs
+++ b/crates/iceberg/src/expr/predicate.rs
@@ -116,10 +116,13 @@ impl<T> UnaryExpression<T> {
         debug_assert!(op.is_unary());
         Self { op, term }
     }
+
+    /// Return the operator of this predicate.
     pub(crate) fn op(&self) -> PredicateOperator {
         self.op
     }
 
+    /// Return the term of this predicate.
     pub(crate) fn term(&self) -> &T {
         &self.term
     }
@@ -155,10 +158,13 @@ impl<T> BinaryExpression<T> {
     pub(crate) fn op(&self) -> PredicateOperator {
         self.op
     }
+
+    /// Return the literal of this predicate.
     pub(crate) fn literal(&self) -> &Datum {
         &self.literal
     }
 
+    /// Return the term of this predicate.
     pub(crate) fn term(&self) -> &T {
         &self.term
     }
@@ -210,13 +216,16 @@ impl<T> SetExpression<T> {
         Self { op, term, literals }
     }
 
+    /// Return the operator of this predicate.
     pub(crate) fn op(&self) -> PredicateOperator {
         self.op
     }
+
     pub(crate) fn literals(&self) -> &FnvHashSet<Datum> {
         &self.literals
     }
 
+    /// Return the term of this predicate.
     pub(crate) fn term(&self) -> &T {
         &self.term
     }
diff --git a/crates/iceberg/src/scan.rs b/crates/iceberg/src/scan.rs
index c2a5e1b..906bf96 100644
--- a/crates/iceberg/src/scan.rs
+++ b/crates/iceberg/src/scan.rs
@@ -46,6 +46,7 @@ pub struct TableScanBuilder<'a> {
     table: &'a Table,
     // Empty column names means to select all columns
     column_names: Vec<String>,
+    predicates: Option<Predicate>,
     snapshot_id: Option<i64>,
     batch_size: Option<usize>,
     case_sensitive: bool,
@@ -57,6 +58,7 @@ impl<'a> TableScanBuilder<'a> {
         Self {
             table,
             column_names: vec![],
+            predicates: None,
             snapshot_id: None,
             batch_size: None,
             case_sensitive: true,
@@ -91,6 +93,12 @@ impl<'a> TableScanBuilder<'a> {
         self
     }
 
+    /// Add a predicate to the scan. The scan will only return rows that match 
the predicate.
+    pub fn filter(mut self, predicate: Predicate) -> Self {
+        self.predicates = Some(predicate);
+        self
+    }
+
     /// Select some columns of the table.
     pub fn select(mut self, column_names: impl IntoIterator<Item = impl 
ToString>) -> Self {
         self.column_names = column_names
@@ -150,11 +158,18 @@ impl<'a> TableScanBuilder<'a> {
             }
         }
 
+        let bound_predicates = if let Some(ref predicates) = self.predicates {
+            Some(predicates.bind(schema.clone(), true)?)
+        } else {
+            None
+        };
+
         Ok(TableScan {
             snapshot,
             file_io: self.table.file_io().clone(),
             table_metadata: self.table.metadata_ref(),
             column_names: self.column_names,
+            bound_predicates,
             schema,
             batch_size: self.batch_size,
             case_sensitive: self.case_sensitive,
@@ -171,6 +186,7 @@ pub struct TableScan {
     table_metadata: TableMetadataRef,
     file_io: FileIO,
     column_names: Vec<String>,
+    bound_predicates: Option<BoundPredicate>,
     schema: SchemaRef,
     batch_size: Option<usize>,
     case_sensitive: bool,
@@ -300,6 +316,10 @@ impl TableScan {
             arrow_reader_builder = arrow_reader_builder.with_batch_size(batch_size);
         }
 
+        if let Some(ref bound_predicates) = self.bound_predicates {
+            arrow_reader_builder = arrow_reader_builder.with_predicates(bound_predicates.clone());
+        }
+
         arrow_reader_builder.build().read(self.plan_files().await?)
     }
 
@@ -449,9 +469,10 @@ impl FileScanTask {
 
 #[cfg(test)]
 mod tests {
+    use crate::expr::Reference;
     use crate::io::{FileIO, OutputFile};
     use crate::spec::{
-        DataContentType, DataFileBuilder, DataFileFormat, FormatVersion, Literal, Manifest,
+        DataContentType, DataFileBuilder, DataFileFormat, Datum, FormatVersion, Literal, Manifest,
         ManifestContentType, ManifestEntry, ManifestListWriter, ManifestMetadata, ManifestStatus,
         ManifestWriter, Struct, TableMetadata, EMPTY_SNAPSHOT_ID,
     };
@@ -642,9 +663,23 @@ mod tests {
                 ];
                 Arc::new(arrow_schema::Schema::new(fields))
             };
+            // 3 columns:
+            // x: [1, 1, 1, 1, ...]
             let col1 = Arc::new(Int64Array::from_iter_values(vec![1; 1024])) as ArrayRef;
-            let col2 = Arc::new(Int64Array::from_iter_values(vec![2; 1024])) as ArrayRef;
-            let col3 = Arc::new(Int64Array::from_iter_values(vec![3; 1024])) as ArrayRef;
+
+            let mut values = vec![2; 512];
+            values.append(vec![3; 200].as_mut());
+            values.append(vec![4; 300].as_mut());
+            values.append(vec![5; 12].as_mut());
+
+            // y: [2, 2, 2, 2, ..., 3, 3, 3, 3, ..., 4, 4, 4, 4, ..., 5, 5, 5, 5]
+            let col2 = Arc::new(Int64Array::from_iter_values(values)) as ArrayRef;
+
+            let mut values = vec![3; 512];
+            values.append(vec![4; 512].as_mut());
+
+            // z: [3, 3, 3, 3, ..., 4, 4, 4, 4]
+            let col3 = Arc::new(Int64Array::from_iter_values(values)) as ArrayRef;
             let to_write = RecordBatch::try_new(schema.clone(), vec![col1, col2, col3]).unwrap();
 
             // Write the Parquet files
@@ -803,4 +838,161 @@ mod tests {
         let int64_arr = col2.as_any().downcast_ref::<Int64Array>().unwrap();
         assert_eq!(int64_arr.value(0), 3);
     }
+
+    #[tokio::test]
+    async fn test_filter_on_arrow_lt() {
+        let mut fixture = TableTestFixture::new();
+        fixture.setup_manifest_files().await;
+
+        // Filter: y < 3
+        let mut builder = fixture.table.scan();
+        let predicate = Reference::new("y").less_than(Datum::long(3));
+        builder = builder.filter(predicate);
+        let table_scan = builder.build().unwrap();
+
+        let batch_stream = table_scan.to_arrow().await.unwrap();
+
+        let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+
+        assert_eq!(batches[0].num_rows(), 512);
+
+        let col = batches[0].column_by_name("x").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 1);
+
+        let col = batches[0].column_by_name("y").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 2);
+    }
+
+    #[tokio::test]
+    async fn test_filter_on_arrow_gt_eq() {
+        let mut fixture = TableTestFixture::new();
+        fixture.setup_manifest_files().await;
+
+        // Filter: y >= 5
+        let mut builder = fixture.table.scan();
+        let predicate = Reference::new("y").greater_than_or_equal_to(Datum::long(5));
+        builder = builder.filter(predicate);
+        let table_scan = builder.build().unwrap();
+
+        let batch_stream = table_scan.to_arrow().await.unwrap();
+
+        let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+
+        assert_eq!(batches[0].num_rows(), 12);
+
+        let col = batches[0].column_by_name("x").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 1);
+
+        let col = batches[0].column_by_name("y").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 5);
+    }
+
+    #[tokio::test]
+    async fn test_filter_on_arrow_is_null() {
+        let mut fixture = TableTestFixture::new();
+        fixture.setup_manifest_files().await;
+
+        // Filter: y is null
+        let mut builder = fixture.table.scan();
+        let predicate = Reference::new("y").is_null();
+        builder = builder.filter(predicate);
+        let table_scan = builder.build().unwrap();
+
+        let batch_stream = table_scan.to_arrow().await.unwrap();
+
+        let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+        assert_eq!(batches.len(), 0);
+    }
+
+    #[tokio::test]
+    async fn test_filter_on_arrow_is_not_null() {
+        let mut fixture = TableTestFixture::new();
+        fixture.setup_manifest_files().await;
+
+        // Filter: y is not null
+        let mut builder = fixture.table.scan();
+        let predicate = Reference::new("y").is_not_null();
+        builder = builder.filter(predicate);
+        let table_scan = builder.build().unwrap();
+
+        let batch_stream = table_scan.to_arrow().await.unwrap();
+
+        let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+        assert_eq!(batches[0].num_rows(), 1024);
+    }
+
+    #[tokio::test]
+    async fn test_filter_on_arrow_lt_and_gt() {
+        let mut fixture = TableTestFixture::new();
+        fixture.setup_manifest_files().await;
+
+        // Filter: y < 5 AND z >= 4
+        let mut builder = fixture.table.scan();
+        let predicate = Reference::new("y")
+            .less_than(Datum::long(5))
+            .and(Reference::new("z").greater_than_or_equal_to(Datum::long(4)));
+        builder = builder.filter(predicate);
+        let table_scan = builder.build().unwrap();
+
+        let batch_stream = table_scan.to_arrow().await.unwrap();
+
+        let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+        assert_eq!(batches[0].num_rows(), 500);
+
+        let col = batches[0].column_by_name("x").unwrap();
+        let expected_x = Arc::new(Int64Array::from_iter_values(vec![1; 500])) as ArrayRef;
+        assert_eq!(col, &expected_x);
+
+        let col = batches[0].column_by_name("y").unwrap();
+        let mut values = vec![];
+        values.append(vec![3; 200].as_mut());
+        values.append(vec![4; 300].as_mut());
+        let expected_y = Arc::new(Int64Array::from_iter_values(values)) as ArrayRef;
+        assert_eq!(col, &expected_y);
+
+        let col = batches[0].column_by_name("z").unwrap();
+        let expected_z = Arc::new(Int64Array::from_iter_values(vec![4; 500])) as ArrayRef;
+        assert_eq!(col, &expected_z);
+    }
+
+    #[tokio::test]
+    async fn test_filter_on_arrow_lt_or_gt() {
+        let mut fixture = TableTestFixture::new();
+        fixture.setup_manifest_files().await;
+
+        // Filter: y < 5 OR z >= 4
+        let mut builder = fixture.table.scan();
+        let predicate = Reference::new("y")
+            .less_than(Datum::long(5))
+            .or(Reference::new("z").greater_than_or_equal_to(Datum::long(4)));
+        builder = builder.filter(predicate);
+        let table_scan = builder.build().unwrap();
+
+        let batch_stream = table_scan.to_arrow().await.unwrap();
+
+        let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+        assert_eq!(batches[0].num_rows(), 1024);
+
+        let col = batches[0].column_by_name("x").unwrap();
+        let expected_x = Arc::new(Int64Array::from_iter_values(vec![1; 1024])) as ArrayRef;
+        assert_eq!(col, &expected_x);
+
+        let col = batches[0].column_by_name("y").unwrap();
+        let mut values = vec![2; 512];
+        values.append(vec![3; 200].as_mut());
+        values.append(vec![4; 300].as_mut());
+        values.append(vec![5; 12].as_mut());
+        let expected_y = Arc::new(Int64Array::from_iter_values(values)) as ArrayRef;
+        assert_eq!(col, &expected_y);
+
+        let col = batches[0].column_by_name("z").unwrap();
+        let mut values = vec![3; 512];
+        values.append(vec![4; 512].as_mut());
+        let expected_z = Arc::new(Int64Array::from_iter_values(values)) as ArrayRef;
+        assert_eq!(col, &expected_z);
+    }
 }


Reply via email to