This is an automated email from the ASF dual-hosted git repository.
liurenjie1024 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-rust.git
The following commit(s) were added to refs/heads/main by this push:
new 81df940 feat: Convert predicate to arrow filter and push down to
parquet reader (#295)
81df940 is described below
commit 81df9406a4dd96de8ad82f5f9253738b44a3fe31
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue May 14 18:46:44 2024 -0700
feat: Convert predicate to arrow filter and push down to parquet reader
(#295)
* feat: Convert predicate to arrow filter and push down to parquet reader
* For review
* Fix clippy
* Change from vector of BoundPredicate to BoundPredicate
* Add test for CollectFieldIdVisitor
* Return projection_mask for leaf column
* Update
* For review
* For review
* For review
* For review
* More
* fix
* Fix clippy
* More
* Fix clippy
* fix clippy
---
Cargo.toml | 1 +
crates/iceberg/Cargo.toml | 1 +
crates/iceberg/src/arrow/reader.rs | 678 ++++++++++++++++++++++++++++++++++-
crates/iceberg/src/arrow/schema.rs | 26 +-
crates/iceberg/src/expr/predicate.rs | 9 +
crates/iceberg/src/scan.rs | 198 +++++++++-
6 files changed, 903 insertions(+), 10 deletions(-)
diff --git a/Cargo.toml b/Cargo.toml
index 57c3436..125cd0d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -40,6 +40,7 @@ apache-avro = "0.16"
array-init = "2"
arrow-arith = { version = "51" }
arrow-array = { version = "51" }
+arrow-ord = { version = "51" }
arrow-schema = { version = "51" }
arrow-select = { version = "51" }
async-stream = "0.3.5"
diff --git a/crates/iceberg/Cargo.toml b/crates/iceberg/Cargo.toml
index 46f167b..95e0078 100644
--- a/crates/iceberg/Cargo.toml
+++ b/crates/iceberg/Cargo.toml
@@ -34,6 +34,7 @@ apache-avro = { workspace = true }
array-init = { workspace = true }
arrow-arith = { workspace = true }
arrow-array = { workspace = true }
+arrow-ord = { workspace = true }
arrow-schema = { workspace = true }
arrow-select = { workspace = true }
async-stream = { workspace = true }
diff --git a/crates/iceberg/src/arrow/reader.rs
b/crates/iceberg/src/arrow/reader.rs
index fe5efac..391239c 100644
--- a/crates/iceberg/src/arrow/reader.rs
+++ b/crates/iceberg/src/arrow/reader.rs
@@ -17,25 +17,33 @@
//! Parquet file data reader
-use arrow_schema::SchemaRef as ArrowSchemaRef;
+use crate::error::Result;
+use arrow_arith::boolean::{and, is_not_null, is_null, not, or};
+use arrow_array::{ArrayRef, BooleanArray, RecordBatch};
+use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq};
+use arrow_schema::{ArrowError, DataType, SchemaRef as ArrowSchemaRef};
use async_stream::try_stream;
use bytes::Bytes;
+use fnv::FnvHashSet;
use futures::future::BoxFuture;
use futures::stream::StreamExt;
use futures::{try_join, TryFutureExt};
+use parquet::arrow::arrow_reader::{ArrowPredicateFn, RowFilter};
use parquet::arrow::async_reader::{AsyncFileReader, MetadataLoader};
use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask,
PARQUET_FIELD_ID_META_KEY};
use parquet::file::metadata::ParquetMetaData;
-use parquet::schema::types::SchemaDescriptor;
-use std::collections::HashMap;
+use parquet::schema::types::{SchemaDescriptor, Type as ParquetType};
+use std::collections::{HashMap, HashSet};
use std::ops::Range;
use std::str::FromStr;
use std::sync::Arc;
-use crate::arrow::arrow_schema_to_schema;
+use crate::arrow::{arrow_schema_to_schema, get_arrow_datum};
+use crate::expr::visitors::bound_predicate_visitor::{visit,
BoundPredicateVisitor};
+use crate::expr::{BoundPredicate, BoundReference};
use crate::io::{FileIO, FileMetadata, FileRead};
use crate::scan::{ArrowRecordBatchStream, FileScanTaskStream};
-use crate::spec::SchemaRef;
+use crate::spec::{Datum, SchemaRef};
use crate::{Error, ErrorKind};
/// Builder to create ArrowReader
@@ -44,6 +52,7 @@ pub struct ArrowReaderBuilder {
field_ids: Vec<usize>,
file_io: FileIO,
schema: SchemaRef,
+ predicates: Option<BoundPredicate>,
}
impl ArrowReaderBuilder {
@@ -54,6 +63,7 @@ impl ArrowReaderBuilder {
field_ids: vec![],
file_io,
schema,
+ predicates: None,
}
}
@@ -70,6 +80,12 @@ impl ArrowReaderBuilder {
self
}
+ /// Sets the predicates to apply to the scan.
+ pub fn with_predicates(mut self, predicates: BoundPredicate) -> Self {
+ self.predicates = Some(predicates);
+ self
+ }
+
/// Build the ArrowReader.
pub fn build(self) -> ArrowReader {
ArrowReader {
@@ -77,6 +93,7 @@ impl ArrowReaderBuilder {
field_ids: self.field_ids,
schema: self.schema,
file_io: self.file_io,
+ predicates: self.predicates,
}
}
}
@@ -88,6 +105,7 @@ pub struct ArrowReader {
#[allow(dead_code)]
schema: SchemaRef,
file_io: FileIO,
+ predicates: Option<BoundPredicate>,
}
impl ArrowReader {
@@ -96,6 +114,14 @@ impl ArrowReader {
pub fn read(self, mut tasks: FileScanTaskStream) ->
crate::Result<ArrowRecordBatchStream> {
let file_io = self.file_io.clone();
+ // Collect the field ids referenced by the predicates
+ let mut collector = CollectFieldIdVisitor {
+ field_ids: HashSet::default(),
+ };
+ if let Some(predicates) = &self.predicates {
+ visit(&mut collector, predicates)?;
+ }
+
Ok(try_stream! {
while let Some(Ok(task)) = tasks.next().await {
let parquet_file = file_io
@@ -111,6 +137,13 @@ impl ArrowReader {
let projection_mask =
self.get_arrow_projection_mask(parquet_schema, arrow_schema)?;
batch_stream_builder =
batch_stream_builder.with_projection(projection_mask);
+ let parquet_schema = batch_stream_builder.parquet_schema();
+ let row_filter = self.get_row_filter(parquet_schema,
&collector)?;
+
+ if let Some(row_filter) = row_filter {
+ batch_stream_builder =
batch_stream_builder.with_row_filter(row_filter);
+ }
+
if let Some(batch_size) = self.batch_size {
batch_stream_builder =
batch_stream_builder.with_batch_size(batch_size);
}
@@ -193,6 +226,558 @@ impl ArrowReader {
Ok(ProjectionMask::leaves(parquet_schema, indices))
}
}
+
+ fn get_row_filter(
+ &self,
+ parquet_schema: &SchemaDescriptor,
+ collector: &CollectFieldIdVisitor,
+ ) -> Result<Option<RowFilter>> {
+ if let Some(predicates) = &self.predicates {
+ let field_id_map = build_field_id_map(parquet_schema)?;
+
+ // Collect Parquet column indices from field ids.
+ // If the field id is not found in Parquet schema, it will be
ignored due to schema evolution.
+ let mut column_indices = collector
+ .field_ids
+ .iter()
+ .filter_map(|field_id| field_id_map.get(field_id).cloned())
+ .collect::<Vec<_>>();
+
+ column_indices.sort();
+
+ // The converter that converts `BoundPredicates` to
`ArrowPredicates`
+ let mut converter = PredicateConverter {
+ parquet_schema,
+ column_map: &field_id_map,
+ column_indices: &column_indices,
+ };
+
+ // After collecting required leaf column indices used in the
predicate,
+ // creates the projection mask for the Arrow predicates.
+ let projection_mask = ProjectionMask::leaves(parquet_schema,
column_indices.clone());
+ let predicate_func = visit(&mut converter, predicates)?;
+ let arrow_predicate = ArrowPredicateFn::new(projection_mask,
predicate_func);
+ Ok(Some(RowFilter::new(vec![Box::new(arrow_predicate)])))
+ } else {
+ Ok(None)
+ }
+ }
+}
+
+/// Build the map of field id to Parquet column index in the schema.
+fn build_field_id_map(parquet_schema: &SchemaDescriptor) ->
Result<HashMap<i32, usize>> {
+ let mut column_map = HashMap::new();
+ for (idx, field) in parquet_schema.columns().iter().enumerate() {
+ let field_type = field.self_type();
+ match field_type {
+ ParquetType::PrimitiveType { basic_info, .. } => {
+ if !basic_info.has_id() {
+ return Err(Error::new(
+ ErrorKind::DataInvalid,
+ format!(
+ "Leave column idx: {}, name: {}, type {:?} in
schema doesn't have field id",
+ idx,
+ basic_info.name(),
+ field_type
+ ),
+ ));
+ }
+ column_map.insert(basic_info.id(), idx);
+ }
+ ParquetType::GroupType { .. } => {
+ return Err(Error::new(
+ ErrorKind::DataInvalid,
+ format!(
+ "Leave column in schema should be primitive type but
got {:?}",
+ field_type
+ ),
+ ));
+ }
+ };
+ }
+
+ Ok(column_map)
+}
+
+/// A visitor to collect field ids from bound predicates.
+struct CollectFieldIdVisitor {
+ field_ids: HashSet<i32>,
+}
+
+impl BoundPredicateVisitor for CollectFieldIdVisitor {
+ type T = ();
+
+ fn always_true(&mut self) -> Result<()> {
+ Ok(())
+ }
+
+ fn always_false(&mut self) -> Result<()> {
+ Ok(())
+ }
+
+ fn and(&mut self, _lhs: (), _rhs: ()) -> Result<()> {
+ Ok(())
+ }
+
+ fn or(&mut self, _lhs: (), _rhs: ()) -> Result<()> {
+ Ok(())
+ }
+
+ fn not(&mut self, _inner: ()) -> Result<()> {
+ Ok(())
+ }
+
+ fn is_null(&mut self, reference: &BoundReference, _predicate:
&BoundPredicate) -> Result<()> {
+ self.field_ids.insert(reference.field().id);
+ Ok(())
+ }
+
+ fn not_null(&mut self, reference: &BoundReference, _predicate:
&BoundPredicate) -> Result<()> {
+ self.field_ids.insert(reference.field().id);
+ Ok(())
+ }
+
+ fn is_nan(&mut self, reference: &BoundReference, _predicate:
&BoundPredicate) -> Result<()> {
+ self.field_ids.insert(reference.field().id);
+ Ok(())
+ }
+
+ fn not_nan(&mut self, reference: &BoundReference, _predicate:
&BoundPredicate) -> Result<()> {
+ self.field_ids.insert(reference.field().id);
+ Ok(())
+ }
+
+ fn less_than(
+ &mut self,
+ reference: &BoundReference,
+ _literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<()> {
+ self.field_ids.insert(reference.field().id);
+ Ok(())
+ }
+
+ fn less_than_or_eq(
+ &mut self,
+ reference: &BoundReference,
+ _literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<()> {
+ self.field_ids.insert(reference.field().id);
+ Ok(())
+ }
+
+ fn greater_than(
+ &mut self,
+ reference: &BoundReference,
+ _literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<()> {
+ self.field_ids.insert(reference.field().id);
+ Ok(())
+ }
+
+ fn greater_than_or_eq(
+ &mut self,
+ reference: &BoundReference,
+ _literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<()> {
+ self.field_ids.insert(reference.field().id);
+ Ok(())
+ }
+
+ fn eq(
+ &mut self,
+ reference: &BoundReference,
+ _literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<()> {
+ self.field_ids.insert(reference.field().id);
+ Ok(())
+ }
+
+ fn not_eq(
+ &mut self,
+ reference: &BoundReference,
+ _literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<()> {
+ self.field_ids.insert(reference.field().id);
+ Ok(())
+ }
+
+ fn starts_with(
+ &mut self,
+ reference: &BoundReference,
+ _literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<()> {
+ self.field_ids.insert(reference.field().id);
+ Ok(())
+ }
+
+ fn not_starts_with(
+ &mut self,
+ reference: &BoundReference,
+ _literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<()> {
+ self.field_ids.insert(reference.field().id);
+ Ok(())
+ }
+
+ fn r#in(
+ &mut self,
+ reference: &BoundReference,
+ _literals: &FnvHashSet<Datum>,
+ _predicate: &BoundPredicate,
+ ) -> Result<()> {
+ self.field_ids.insert(reference.field().id);
+ Ok(())
+ }
+
+ fn not_in(
+ &mut self,
+ reference: &BoundReference,
+ _literals: &FnvHashSet<Datum>,
+ _predicate: &BoundPredicate,
+ ) -> Result<()> {
+ self.field_ids.insert(reference.field().id);
+ Ok(())
+ }
+}
+
+/// A visitor to convert Iceberg bound predicates to Arrow predicates.
+struct PredicateConverter<'a> {
+ /// The Parquet schema descriptor.
+ pub parquet_schema: &'a SchemaDescriptor,
+ /// The map between field id and leaf column index in Parquet schema.
+ pub column_map: &'a HashMap<i32, usize>,
+ /// The required column indices in Parquet schema for the predicates.
+ pub column_indices: &'a Vec<usize>,
+}
+
+impl PredicateConverter<'_> {
+ /// When visiting a bound reference, we return index of the leaf column in
the
+ /// required column indices which is used to project the column in the
record batch.
+ /// Return None if the field id is not found in the column map, which is
possible
+ /// due to schema evolution.
+ fn bound_reference(&mut self, reference: &BoundReference) ->
Result<Option<usize>> {
+ // The leaf column's index in Parquet schema.
+ if let Some(column_idx) = self.column_map.get(&reference.field().id) {
+ if self.parquet_schema.get_column_root_idx(*column_idx) !=
*column_idx {
+ return Err(Error::new(
+ ErrorKind::DataInvalid,
+ format!(
+ "Leave column `{}` in predicates isn't a root column
in Parquet schema.",
+ reference.field().name
+ ),
+ ));
+ }
+
+ // The leaf column's index in the required column indices.
+ let index = self
+ .column_indices
+ .iter()
+ .position(|&idx| idx ==
*column_idx).ok_or(Error::new(ErrorKind::DataInvalid, format!(
+ "Leave column `{}` in predicates cannot be found in the
required column indices.",
+ reference.field().name
+ )))?;
+
+ Ok(Some(index))
+ } else {
+ Ok(None)
+ }
+ }
+
+ /// Build an Arrow predicate that always returns true.
+ fn build_always_true(&self) -> Result<Box<PredicateResult>> {
+ Ok(Box::new(|batch| {
+ Ok(BooleanArray::from(vec![true; batch.num_rows()]))
+ }))
+ }
+
+ /// Build an Arrow predicate that always returns false.
+ fn build_always_false(&self) -> Result<Box<PredicateResult>> {
+ Ok(Box::new(|batch| {
+ Ok(BooleanArray::from(vec![false; batch.num_rows()]))
+ }))
+ }
+}
+
+/// Gets the leaf column from the record batch for the required column index.
Only
+/// supports top-level columns for now.
+fn project_column(
+ batch: &RecordBatch,
+ column_idx: usize,
+) -> std::result::Result<ArrayRef, ArrowError> {
+ let column = batch.column(column_idx);
+
+ match column.data_type() {
+ DataType::Struct(_) => Err(ArrowError::SchemaError(
+ "Does not support struct column yet.".to_string(),
+ )),
+ _ => Ok(column.clone()),
+ }
+}
+
+type PredicateResult =
+ dyn FnMut(RecordBatch) -> std::result::Result<BooleanArray, ArrowError> +
Send + 'static;
+
+impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
+ type T = Box<PredicateResult>;
+
+ fn always_true(&mut self) -> Result<Box<PredicateResult>> {
+ self.build_always_true()
+ }
+
+ fn always_false(&mut self) -> Result<Box<PredicateResult>> {
+ self.build_always_false()
+ }
+
+ fn and(
+ &mut self,
+ mut lhs: Box<PredicateResult>,
+ mut rhs: Box<PredicateResult>,
+ ) -> Result<Box<PredicateResult>> {
+ Ok(Box::new(move |batch| {
+ let left = lhs(batch.clone())?;
+ let right = rhs(batch)?;
+ and(&left, &right)
+ }))
+ }
+
+ fn or(
+ &mut self,
+ mut lhs: Box<PredicateResult>,
+ mut rhs: Box<PredicateResult>,
+ ) -> Result<Box<PredicateResult>> {
+ Ok(Box::new(move |batch| {
+ let left = lhs(batch.clone())?;
+ let right = rhs(batch)?;
+ or(&left, &right)
+ }))
+ }
+
+ fn not(&mut self, mut inner: Box<PredicateResult>) ->
Result<Box<PredicateResult>> {
+ Ok(Box::new(move |batch| {
+ let pred_ret = inner(batch)?;
+ not(&pred_ret)
+ }))
+ }
+
+ fn is_null(
+ &mut self,
+ reference: &BoundReference,
+ _predicate: &BoundPredicate,
+ ) -> Result<Box<PredicateResult>> {
+ if let Some(idx) = self.bound_reference(reference)? {
+ Ok(Box::new(move |batch| {
+ let column = project_column(&batch, idx)?;
+ is_null(&column)
+ }))
+ } else {
+ // A missing column, treating it as null.
+ self.build_always_true()
+ }
+ }
+
+ fn not_null(
+ &mut self,
+ reference: &BoundReference,
+ _predicate: &BoundPredicate,
+ ) -> Result<Box<PredicateResult>> {
+ if let Some(idx) = self.bound_reference(reference)? {
+ Ok(Box::new(move |batch| {
+ let column = project_column(&batch, idx)?;
+ is_not_null(&column)
+ }))
+ } else {
+ // A missing column, treating it as null.
+ self.build_always_false()
+ }
+ }
+
+ fn is_nan(
+ &mut self,
+ reference: &BoundReference,
+ _predicate: &BoundPredicate,
+ ) -> Result<Box<PredicateResult>> {
+ if self.bound_reference(reference)?.is_some() {
+ self.build_always_true()
+ } else {
+ // A missing column, treating it as null.
+ self.build_always_false()
+ }
+ }
+
+ fn not_nan(
+ &mut self,
+ reference: &BoundReference,
+ _predicate: &BoundPredicate,
+ ) -> Result<Box<PredicateResult>> {
+ if self.bound_reference(reference)?.is_some() {
+ self.build_always_false()
+ } else {
+ // A missing column, treating it as null.
+ self.build_always_true()
+ }
+ }
+
+ fn less_than(
+ &mut self,
+ reference: &BoundReference,
+ literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<Box<PredicateResult>> {
+ if let Some(idx) = self.bound_reference(reference)? {
+ let literal = get_arrow_datum(literal)?;
+
+ Ok(Box::new(move |batch| {
+ let left = project_column(&batch, idx)?;
+ lt(&left, literal.as_ref())
+ }))
+ } else {
+ // A missing column, treating it as null.
+ self.build_always_true()
+ }
+ }
+
+ fn less_than_or_eq(
+ &mut self,
+ reference: &BoundReference,
+ literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<Box<PredicateResult>> {
+ if let Some(idx) = self.bound_reference(reference)? {
+ let literal = get_arrow_datum(literal)?;
+
+ Ok(Box::new(move |batch| {
+ let left = project_column(&batch, idx)?;
+ lt_eq(&left, literal.as_ref())
+ }))
+ } else {
+ // A missing column, treating it as null.
+ self.build_always_true()
+ }
+ }
+
+ fn greater_than(
+ &mut self,
+ reference: &BoundReference,
+ literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<Box<PredicateResult>> {
+ if let Some(idx) = self.bound_reference(reference)? {
+ let literal = get_arrow_datum(literal)?;
+
+ Ok(Box::new(move |batch| {
+ let left = project_column(&batch, idx)?;
+ gt(&left, literal.as_ref())
+ }))
+ } else {
+ // A missing column, treating it as null.
+ self.build_always_false()
+ }
+ }
+
+ fn greater_than_or_eq(
+ &mut self,
+ reference: &BoundReference,
+ literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<Box<PredicateResult>> {
+ if let Some(idx) = self.bound_reference(reference)? {
+ let literal = get_arrow_datum(literal)?;
+
+ Ok(Box::new(move |batch| {
+ let left = project_column(&batch, idx)?;
+ gt_eq(&left, literal.as_ref())
+ }))
+ } else {
+ // A missing column, treating it as null.
+ self.build_always_false()
+ }
+ }
+
+ fn eq(
+ &mut self,
+ reference: &BoundReference,
+ literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<Box<PredicateResult>> {
+ if let Some(idx) = self.bound_reference(reference)? {
+ let literal = get_arrow_datum(literal)?;
+
+ Ok(Box::new(move |batch| {
+ let left = project_column(&batch, idx)?;
+ eq(&left, literal.as_ref())
+ }))
+ } else {
+ // A missing column, treating it as null.
+ self.build_always_false()
+ }
+ }
+
+ fn not_eq(
+ &mut self,
+ reference: &BoundReference,
+ literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<Box<PredicateResult>> {
+ if let Some(idx) = self.bound_reference(reference)? {
+ let literal = get_arrow_datum(literal)?;
+
+ Ok(Box::new(move |batch| {
+ let left = project_column(&batch, idx)?;
+ neq(&left, literal.as_ref())
+ }))
+ } else {
+ // A missing column, treating it as null.
+ self.build_always_false()
+ }
+ }
+
+ fn starts_with(
+ &mut self,
+ _reference: &BoundReference,
+ _literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<Box<PredicateResult>> {
+ // TODO: Implement starts_with
+ self.build_always_true()
+ }
+
+ fn not_starts_with(
+ &mut self,
+ _reference: &BoundReference,
+ _literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<Box<PredicateResult>> {
+ // TODO: Implement not_starts_with
+ self.build_always_true()
+ }
+
+ fn r#in(
+ &mut self,
+ _reference: &BoundReference,
+ _literals: &FnvHashSet<Datum>,
+ _predicate: &BoundPredicate,
+ ) -> Result<Box<PredicateResult>> {
+ // TODO: Implement in
+ self.build_always_true()
+ }
+
+ fn not_in(
+ &mut self,
+ _reference: &BoundReference,
+ _literals: &FnvHashSet<Datum>,
+ _predicate: &BoundPredicate,
+ ) -> Result<Box<PredicateResult>> {
+ // TODO: Implement not_in
+ self.build_always_true()
+ }
}
/// ArrowFileReader is a wrapper around a FileRead that impls parquets
AsyncFileReader.
@@ -234,3 +819,86 @@ impl<R: FileRead> AsyncFileReader for ArrowFileReader<R> {
})
}
}
+
+#[cfg(test)]
+mod tests {
+ use crate::arrow::reader::CollectFieldIdVisitor;
+ use crate::expr::visitors::bound_predicate_visitor::visit;
+ use crate::expr::{Bind, Reference};
+ use crate::spec::{NestedField, PrimitiveType, Schema, SchemaRef, Type};
+ use std::collections::HashSet;
+ use std::sync::Arc;
+
+ fn table_schema_simple() -> SchemaRef {
+ Arc::new(
+ Schema::builder()
+ .with_schema_id(1)
+ .with_identifier_field_ids(vec![2])
+ .with_fields(vec![
+ NestedField::optional(1, "foo",
Type::Primitive(PrimitiveType::String)).into(),
+ NestedField::required(2, "bar",
Type::Primitive(PrimitiveType::Int)).into(),
+ NestedField::optional(3, "baz",
Type::Primitive(PrimitiveType::Boolean)).into(),
+ NestedField::optional(4, "qux",
Type::Primitive(PrimitiveType::Float)).into(),
+ ])
+ .build()
+ .unwrap(),
+ )
+ }
+
+ #[test]
+ fn test_collect_field_id() {
+ let schema = table_schema_simple();
+ let expr = Reference::new("qux").is_null();
+ let bound_expr = expr.bind(schema, true).unwrap();
+
+ let mut visitor = CollectFieldIdVisitor {
+ field_ids: HashSet::default(),
+ };
+ visit(&mut visitor, &bound_expr).unwrap();
+
+ let mut expected = HashSet::default();
+ expected.insert(4_i32);
+
+ assert_eq!(visitor.field_ids, expected);
+ }
+
+ #[test]
+ fn test_collect_field_id_with_and() {
+ let schema = table_schema_simple();
+ let expr = Reference::new("qux")
+ .is_null()
+ .and(Reference::new("baz").is_null());
+ let bound_expr = expr.bind(schema, true).unwrap();
+
+ let mut visitor = CollectFieldIdVisitor {
+ field_ids: HashSet::default(),
+ };
+ visit(&mut visitor, &bound_expr).unwrap();
+
+ let mut expected = HashSet::default();
+ expected.insert(4_i32);
+ expected.insert(3);
+
+ assert_eq!(visitor.field_ids, expected);
+ }
+
+ #[test]
+ fn test_collect_field_id_with_or() {
+ let schema = table_schema_simple();
+ let expr = Reference::new("qux")
+ .is_null()
+ .or(Reference::new("baz").is_null());
+ let bound_expr = expr.bind(schema, true).unwrap();
+
+ let mut visitor = CollectFieldIdVisitor {
+ field_ids: HashSet::default(),
+ };
+ visit(&mut visitor, &bound_expr).unwrap();
+
+ let mut expected = HashSet::default();
+ expected.insert(4_i32);
+ expected.insert(3);
+
+ assert_eq!(visitor.field_ids, expected);
+ }
+}
diff --git a/crates/iceberg/src/arrow/schema.rs
b/crates/iceberg/src/arrow/schema.rs
index c7e8700..172d4bb 100644
--- a/crates/iceberg/src/arrow/schema.rs
+++ b/crates/iceberg/src/arrow/schema.rs
@@ -19,12 +19,16 @@
use crate::error::Result;
use crate::spec::{
- ListType, MapType, NestedField, NestedFieldRef, PrimitiveType, Schema,
SchemaVisitor,
- StructType, Type,
+ Datum, ListType, MapType, NestedField, NestedFieldRef, PrimitiveLiteral,
PrimitiveType, Schema,
+ SchemaVisitor, StructType, Type,
};
use crate::{Error, ErrorKind};
use arrow_array::types::{validate_decimal_precision_and_scale, Decimal128Type};
+use arrow_array::{
+ BooleanArray, Datum as ArrowDatum, Float32Array, Float64Array, Int32Array,
Int64Array,
+};
use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema, TimeUnit};
+use bitvec::macros::internal::funty::Fundamental;
use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
use rust_decimal::prelude::ToPrimitive;
use std::collections::HashMap;
@@ -593,6 +597,24 @@ pub fn schema_to_arrow_schema(schema:
&crate::spec::Schema) -> crate::Result<Arr
}
}
+/// Convert Iceberg Datum to Arrow Datum.
+pub(crate) fn get_arrow_datum(datum: &Datum) -> Result<Box<dyn ArrowDatum +
Send>> {
+ match datum.literal() {
+ PrimitiveLiteral::Boolean(value) =>
Ok(Box::new(BooleanArray::new_scalar(*value))),
+ PrimitiveLiteral::Int(value) =>
Ok(Box::new(Int32Array::new_scalar(*value))),
+ PrimitiveLiteral::Long(value) =>
Ok(Box::new(Int64Array::new_scalar(*value))),
+ PrimitiveLiteral::Float(value) =>
Ok(Box::new(Float32Array::new_scalar(value.as_f32()))),
+ PrimitiveLiteral::Double(value) =>
Ok(Box::new(Float64Array::new_scalar(value.as_f64()))),
+ l => Err(Error::new(
+ ErrorKind::FeatureUnsupported,
+ format!(
+ "Converting datum from type {:?} to arrow not supported yet.",
+ l
+ ),
+ )),
+ }
+}
+
impl TryFrom<&ArrowSchema> for crate::spec::Schema {
type Error = Error;
diff --git a/crates/iceberg/src/expr/predicate.rs
b/crates/iceberg/src/expr/predicate.rs
index 1457d5a..158ab13 100644
--- a/crates/iceberg/src/expr/predicate.rs
+++ b/crates/iceberg/src/expr/predicate.rs
@@ -116,10 +116,13 @@ impl<T> UnaryExpression<T> {
debug_assert!(op.is_unary());
Self { op, term }
}
+
+ /// Return the operator of this predicate.
pub(crate) fn op(&self) -> PredicateOperator {
self.op
}
+ /// Return the term of this predicate.
pub(crate) fn term(&self) -> &T {
&self.term
}
@@ -155,10 +158,13 @@ impl<T> BinaryExpression<T> {
pub(crate) fn op(&self) -> PredicateOperator {
self.op
}
+
+ /// Return the literal of this predicate.
pub(crate) fn literal(&self) -> &Datum {
&self.literal
}
+ /// Return the term of this predicate.
pub(crate) fn term(&self) -> &T {
&self.term
}
@@ -210,13 +216,16 @@ impl<T> SetExpression<T> {
Self { op, term, literals }
}
+ /// Return the operator of this predicate.
pub(crate) fn op(&self) -> PredicateOperator {
self.op
}
+
pub(crate) fn literals(&self) -> &FnvHashSet<Datum> {
&self.literals
}
+ /// Return the term of this predicate.
pub(crate) fn term(&self) -> &T {
&self.term
}
diff --git a/crates/iceberg/src/scan.rs b/crates/iceberg/src/scan.rs
index c2a5e1b..906bf96 100644
--- a/crates/iceberg/src/scan.rs
+++ b/crates/iceberg/src/scan.rs
@@ -46,6 +46,7 @@ pub struct TableScanBuilder<'a> {
table: &'a Table,
// Empty column names means to select all columns
column_names: Vec<String>,
+ predicates: Option<Predicate>,
snapshot_id: Option<i64>,
batch_size: Option<usize>,
case_sensitive: bool,
@@ -57,6 +58,7 @@ impl<'a> TableScanBuilder<'a> {
Self {
table,
column_names: vec![],
+ predicates: None,
snapshot_id: None,
batch_size: None,
case_sensitive: true,
@@ -91,6 +93,12 @@ impl<'a> TableScanBuilder<'a> {
self
}
+ /// Add a predicate to the scan. The scan will only return rows that match
the predicate.
+ pub fn filter(mut self, predicate: Predicate) -> Self {
+ self.predicates = Some(predicate);
+ self
+ }
+
/// Select some columns of the table.
pub fn select(mut self, column_names: impl IntoIterator<Item = impl
ToString>) -> Self {
self.column_names = column_names
@@ -150,11 +158,18 @@ impl<'a> TableScanBuilder<'a> {
}
}
+ let bound_predicates = if let Some(ref predicates) = self.predicates {
+ Some(predicates.bind(schema.clone(), true)?)
+ } else {
+ None
+ };
+
Ok(TableScan {
snapshot,
file_io: self.table.file_io().clone(),
table_metadata: self.table.metadata_ref(),
column_names: self.column_names,
+ bound_predicates,
schema,
batch_size: self.batch_size,
case_sensitive: self.case_sensitive,
@@ -171,6 +186,7 @@ pub struct TableScan {
table_metadata: TableMetadataRef,
file_io: FileIO,
column_names: Vec<String>,
+ bound_predicates: Option<BoundPredicate>,
schema: SchemaRef,
batch_size: Option<usize>,
case_sensitive: bool,
@@ -300,6 +316,10 @@ impl TableScan {
arrow_reader_builder =
arrow_reader_builder.with_batch_size(batch_size);
}
+ if let Some(ref bound_predicates) = self.bound_predicates {
+ arrow_reader_builder =
arrow_reader_builder.with_predicates(bound_predicates.clone());
+ }
+
arrow_reader_builder.build().read(self.plan_files().await?)
}
@@ -449,9 +469,10 @@ impl FileScanTask {
#[cfg(test)]
mod tests {
+ use crate::expr::Reference;
use crate::io::{FileIO, OutputFile};
use crate::spec::{
- DataContentType, DataFileBuilder, DataFileFormat, FormatVersion,
Literal, Manifest,
+ DataContentType, DataFileBuilder, DataFileFormat, Datum,
FormatVersion, Literal, Manifest,
ManifestContentType, ManifestEntry, ManifestListWriter,
ManifestMetadata, ManifestStatus,
ManifestWriter, Struct, TableMetadata, EMPTY_SNAPSHOT_ID,
};
@@ -642,9 +663,23 @@ mod tests {
];
Arc::new(arrow_schema::Schema::new(fields))
};
+ // 3 columns:
+ // x: [1, 1, 1, 1, ...]
let col1 = Arc::new(Int64Array::from_iter_values(vec![1; 1024]))
as ArrayRef;
- let col2 = Arc::new(Int64Array::from_iter_values(vec![2; 1024]))
as ArrayRef;
- let col3 = Arc::new(Int64Array::from_iter_values(vec![3; 1024]))
as ArrayRef;
+
+ let mut values = vec![2; 512];
+ values.append(vec![3; 200].as_mut());
+ values.append(vec![4; 300].as_mut());
+ values.append(vec![5; 12].as_mut());
+
+ // y: [2, 2, 2, 2, ..., 3, 3, 3, 3, ..., 4, 4, 4, 4, ..., 5, 5, 5,
5]
+ let col2 = Arc::new(Int64Array::from_iter_values(values)) as
ArrayRef;
+
+ let mut values = vec![3; 512];
+ values.append(vec![4; 512].as_mut());
+
+ // z: [3, 3, 3, 3, ..., 4, 4, 4, 4]
+ let col3 = Arc::new(Int64Array::from_iter_values(values)) as
ArrayRef;
let to_write = RecordBatch::try_new(schema.clone(), vec![col1,
col2, col3]).unwrap();
// Write the Parquet files
@@ -803,4 +838,161 @@ mod tests {
let int64_arr = col2.as_any().downcast_ref::<Int64Array>().unwrap();
assert_eq!(int64_arr.value(0), 3);
}
+
+ #[tokio::test]
+ async fn test_filter_on_arrow_lt() {
+ let mut fixture = TableTestFixture::new();
+ fixture.setup_manifest_files().await;
+
+ // Filter: y < 3
+ let mut builder = fixture.table.scan();
+ let predicate = Reference::new("y").less_than(Datum::long(3));
+ builder = builder.filter(predicate);
+ let table_scan = builder.build().unwrap();
+
+ let batch_stream = table_scan.to_arrow().await.unwrap();
+
+ let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+
+ assert_eq!(batches[0].num_rows(), 512);
+
+ let col = batches[0].column_by_name("x").unwrap();
+ let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+ assert_eq!(int64_arr.value(0), 1);
+
+ let col = batches[0].column_by_name("y").unwrap();
+ let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+ assert_eq!(int64_arr.value(0), 2);
+ }
+
+ #[tokio::test]
+ async fn test_filter_on_arrow_gt_eq() {
+ let mut fixture = TableTestFixture::new();
+ fixture.setup_manifest_files().await;
+
+ // Filter: y >= 5
+ let mut builder = fixture.table.scan();
+ let predicate =
Reference::new("y").greater_than_or_equal_to(Datum::long(5));
+ builder = builder.filter(predicate);
+ let table_scan = builder.build().unwrap();
+
+ let batch_stream = table_scan.to_arrow().await.unwrap();
+
+ let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+
+ assert_eq!(batches[0].num_rows(), 12);
+
+ let col = batches[0].column_by_name("x").unwrap();
+ let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+ assert_eq!(int64_arr.value(0), 1);
+
+ let col = batches[0].column_by_name("y").unwrap();
+ let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+ assert_eq!(int64_arr.value(0), 5);
+ }
+
+ #[tokio::test]
+ async fn test_filter_on_arrow_is_null() {
+ let mut fixture = TableTestFixture::new();
+ fixture.setup_manifest_files().await;
+
+ // Filter: y is null
+ let mut builder = fixture.table.scan();
+ let predicate = Reference::new("y").is_null();
+ builder = builder.filter(predicate);
+ let table_scan = builder.build().unwrap();
+
+ let batch_stream = table_scan.to_arrow().await.unwrap();
+
+ let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+ assert_eq!(batches.len(), 0);
+ }
+
+ #[tokio::test]
+ async fn test_filter_on_arrow_is_not_null() {
+ let mut fixture = TableTestFixture::new();
+ fixture.setup_manifest_files().await;
+
+ // Filter: y is not null
+ let mut builder = fixture.table.scan();
+ let predicate = Reference::new("y").is_not_null();
+ builder = builder.filter(predicate);
+ let table_scan = builder.build().unwrap();
+
+ let batch_stream = table_scan.to_arrow().await.unwrap();
+
+ let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+ assert_eq!(batches[0].num_rows(), 1024);
+ }
+
+ #[tokio::test]
+ async fn test_filter_on_arrow_lt_and_gt() {
+ let mut fixture = TableTestFixture::new();
+ fixture.setup_manifest_files().await;
+
+ // Filter: y < 5 AND z >= 4
+ let mut builder = fixture.table.scan();
+ let predicate = Reference::new("y")
+ .less_than(Datum::long(5))
+ .and(Reference::new("z").greater_than_or_equal_to(Datum::long(4)));
+ builder = builder.filter(predicate);
+ let table_scan = builder.build().unwrap();
+
+ let batch_stream = table_scan.to_arrow().await.unwrap();
+
+ let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+ assert_eq!(batches[0].num_rows(), 500);
+
+ let col = batches[0].column_by_name("x").unwrap();
+ let expected_x = Arc::new(Int64Array::from_iter_values(vec![1; 500]))
as ArrayRef;
+ assert_eq!(col, &expected_x);
+
+ let col = batches[0].column_by_name("y").unwrap();
+ let mut values = vec![];
+ values.append(vec![3; 200].as_mut());
+ values.append(vec![4; 300].as_mut());
+ let expected_y = Arc::new(Int64Array::from_iter_values(values)) as
ArrayRef;
+ assert_eq!(col, &expected_y);
+
+ let col = batches[0].column_by_name("z").unwrap();
+ let expected_z = Arc::new(Int64Array::from_iter_values(vec![4; 500]))
as ArrayRef;
+ assert_eq!(col, &expected_z);
+ }
+
+ #[tokio::test]
+ async fn test_filter_on_arrow_lt_or_gt() {
+ let mut fixture = TableTestFixture::new();
+ fixture.setup_manifest_files().await;
+
+ // Filter: y < 5 OR z >= 4
+ let mut builder = fixture.table.scan();
+ let predicate = Reference::new("y")
+ .less_than(Datum::long(5))
+ .or(Reference::new("z").greater_than_or_equal_to(Datum::long(4)));
+ builder = builder.filter(predicate);
+ let table_scan = builder.build().unwrap();
+
+ let batch_stream = table_scan.to_arrow().await.unwrap();
+
+ let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+ assert_eq!(batches[0].num_rows(), 1024);
+
+ let col = batches[0].column_by_name("x").unwrap();
+ let expected_x = Arc::new(Int64Array::from_iter_values(vec![1; 1024]))
as ArrayRef;
+ assert_eq!(col, &expected_x);
+
+ let col = batches[0].column_by_name("y").unwrap();
+ let mut values = vec![2; 512];
+ values.append(vec![3; 200].as_mut());
+ values.append(vec![4; 300].as_mut());
+ values.append(vec![5; 12].as_mut());
+ let expected_y = Arc::new(Int64Array::from_iter_values(values)) as
ArrayRef;
+ assert_eq!(col, &expected_y);
+
+ let col = batches[0].column_by_name("z").unwrap();
+ let mut values = vec![3; 512];
+ values.append(vec![4; 512].as_mut());
+ let expected_z = Arc::new(Int64Array::from_iter_values(values)) as
ArrayRef;
+ assert_eq!(col, &expected_z);
+ }
}