This is an automated email from the ASF dual-hosted git repository.

xuanwo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-rust.git


The following commit(s) were added to refs/heads/main by this push:
     new 1533c437 feat (datafusion integration): convert datafusion expr 
filters to Iceberg Predicate (#588)
1533c437 is described below

commit 1533c4376e751c82097134857c551bd78b8f7b64
Author: Alon Agmon <[email protected]>
AuthorDate: Mon Sep 23 13:39:46 2024 +0300

    feat (datafusion integration): convert datafusion expr filters to Iceberg 
Predicate (#588)
    
    * adding main function and tests
    
    * adding tests, removing integration test for now
    
    * fixing typos and lints
    
    * fixing typing issue
    
    * - added support in schema to convert Date32 to correct arrow type
    - refactored scan to use new predicate converter as visitor and separated 
it into a new mod
    - added support for simple predicates with column cast expressions
    - added testing, mostly around date functions
    
    * fixing format and lic
    
    * reducing number of tests (17 -> 7)
    
    * fix formats
    
    * fix naming
    
    * refactoring to use TreeNodeVisitor
    
    * fixing fmt
    
    * small refactor
    
    * adding swapped op and fixing CR comments
    
    ---------
    
    Co-authored-by: Alon Agmon <[email protected]>
---
 crates/iceberg/src/arrow/schema.rs                 |   7 +-
 .../src/physical_plan/expr_to_predicate.rs         | 335 +++++++++++++++++++++
 .../datafusion/src/physical_plan/mod.rs            |   1 +
 .../datafusion/src/physical_plan/scan.rs           |  38 ++-
 crates/integrations/datafusion/src/table.rs        |  21 +-
 5 files changed, 395 insertions(+), 7 deletions(-)

diff --git a/crates/iceberg/src/arrow/schema.rs 
b/crates/iceberg/src/arrow/schema.rs
index 6c97621c..a32c10a2 100644
--- a/crates/iceberg/src/arrow/schema.rs
+++ b/crates/iceberg/src/arrow/schema.rs
@@ -24,8 +24,8 @@ use arrow_array::types::{
     validate_decimal_precision_and_scale, Decimal128Type, 
TimestampMicrosecondType,
 };
 use arrow_array::{
-    BooleanArray, Datum as ArrowDatum, Float32Array, Float64Array, Int32Array, 
Int64Array,
-    PrimitiveArray, Scalar, StringArray, TimestampMicrosecondArray,
+    BooleanArray, Date32Array, Datum as ArrowDatum, Float32Array, 
Float64Array, Int32Array,
+    Int64Array, PrimitiveArray, Scalar, StringArray, TimestampMicrosecondArray,
 };
 use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema, TimeUnit};
 use bitvec::macros::internal::funty::Fundamental;
@@ -646,6 +646,9 @@ pub(crate) fn get_arrow_datum(datum: &Datum) -> 
Result<Box<dyn ArrowDatum + Send
         (PrimitiveType::String, PrimitiveLiteral::String(value)) => {
             Ok(Box::new(StringArray::new_scalar(value.as_str())))
         }
+        (PrimitiveType::Date, PrimitiveLiteral::Int(value)) => {
+            Ok(Box::new(Date32Array::new_scalar(*value)))
+        }
         (PrimitiveType::Timestamp, PrimitiveLiteral::Long(value)) => {
             Ok(Box::new(TimestampMicrosecondArray::new_scalar(*value)))
         }
diff --git 
a/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs 
b/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs
new file mode 100644
index 00000000..110e4f7e
--- /dev/null
+++ b/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs
@@ -0,0 +1,335 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::collections::VecDeque;
+
+use datafusion::common::tree_node::{TreeNodeRecursion, TreeNodeVisitor};
+use datafusion::common::Column;
+use datafusion::error::DataFusionError;
+use datafusion::logical_expr::{Expr, Operator};
+use datafusion::scalar::ScalarValue;
+use iceberg::expr::{Predicate, Reference};
+use iceberg::spec::Datum;
+
+pub struct ExprToPredicateVisitor {
+    stack: VecDeque<Option<Predicate>>,
+}
+impl ExprToPredicateVisitor {
+    /// Create a new predicate conversion visitor.
+    pub fn new() -> Self {
+        Self {
+            stack: VecDeque::new(),
+        }
+    }
+    /// Get the predicate from the stack.
+    pub fn get_predicate(&self) -> Option<Predicate> {
+        self.stack
+            .iter()
+            .filter_map(|opt| opt.clone())
+            .reduce(Predicate::and)
+    }
+
+    /// Convert a column expression to an iceberg predicate.
+    fn convert_column_expr(
+        &self,
+        col: &Column,
+        op: &Operator,
+        lit: &ScalarValue,
+    ) -> Option<Predicate> {
+        let reference = Reference::new(col.name.clone());
+        let datum = scalar_value_to_datum(lit)?;
+        Some(binary_op_to_predicate(reference, op, datum))
+    }
+
+    /// Convert a compound expression to an iceberg predicate.
+    ///
+    /// The strategy is to support the following cases:
+    /// - if it is an AND expression, the result is built from the valid 
child predicates, whether there are 2 or just 1
+    /// - if it is an OR expression, a predicate is returned only if 
both sides produced a valid predicate
+    fn convert_compound_expr(&self, valid_preds: &[Predicate], op: &Operator) 
-> Option<Predicate> {
+        let valid_preds_count = valid_preds.len();
+        match (op, valid_preds_count) {
+            (Operator::And, 1) => valid_preds.first().cloned(),
+            (Operator::And, 2) => Some(Predicate::and(
+                valid_preds[0].clone(),
+                valid_preds[1].clone(),
+            )),
+            (Operator::Or, 2) => Some(Predicate::or(
+                valid_preds[0].clone(),
+                valid_preds[1].clone(),
+            )),
+            _ => None,
+        }
+    }
+}
+
+// Implement TreeNodeVisitor for ExprToPredicateVisitor
+impl<'n> TreeNodeVisitor<'n> for ExprToPredicateVisitor {
+    type Node = Expr;
+
+    fn f_down(&mut self, _node: &'n Expr) -> Result<TreeNodeRecursion, 
DataFusionError> {
+        Ok(TreeNodeRecursion::Continue)
+    }
+
+    fn f_up(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion, 
DataFusionError> {
+        if let Expr::BinaryExpr(binary) = expr {
+            match (&*binary.left, &binary.op, &*binary.right) {
+                // process simple binary expressions, e.g. col > 1
+                (Expr::Column(col), op, Expr::Literal(lit)) => {
+                    let col_pred = self.convert_column_expr(col, op, lit);
+                    self.stack.push_back(col_pred);
+                }
+                // process reversed binary expressions, e.g. 1 < col
+                (Expr::Literal(lit), op, Expr::Column(col)) => {
+                    let col_pred = op
+                        .swap()
+                        .and_then(|negated_op| self.convert_column_expr(col, 
&negated_op, lit));
+                    self.stack.push_back(col_pred);
+                }
+                // process compound expressions (involving logical operators. 
e.g., AND or OR and children)
+                (_left, op, _right) if op.is_logic_operator() => {
+                    let right_pred = self.stack.pop_back().flatten();
+                    let left_pred = self.stack.pop_back().flatten();
+                    let children: Vec<_> = [left_pred, 
right_pred].into_iter().flatten().collect();
+                    let compound_pred = self.convert_compound_expr(&children, 
op);
+                    self.stack.push_back(compound_pred);
+                }
+                _ => return Ok(TreeNodeRecursion::Continue),
+            }
+        }
+        Ok(TreeNodeRecursion::Continue)
+    }
+}
+
+const MILLIS_PER_DAY: i64 = 24 * 60 * 60 * 1000;
+/// Convert a scalar value to an iceberg datum.
+fn scalar_value_to_datum(value: &ScalarValue) -> Option<Datum> {
+    match value {
+        ScalarValue::Int8(Some(v)) => Some(Datum::int(*v as i32)),
+        ScalarValue::Int16(Some(v)) => Some(Datum::int(*v as i32)),
+        ScalarValue::Int32(Some(v)) => Some(Datum::int(*v)),
+        ScalarValue::Int64(Some(v)) => Some(Datum::long(*v)),
+        ScalarValue::Float32(Some(v)) => Some(Datum::double(*v as f64)),
+        ScalarValue::Float64(Some(v)) => Some(Datum::double(*v)),
+        ScalarValue::Utf8(Some(v)) => Some(Datum::string(v.clone())),
+        ScalarValue::LargeUtf8(Some(v)) => Some(Datum::string(v.clone())),
+        ScalarValue::Date32(Some(v)) => Some(Datum::date(*v)),
+        ScalarValue::Date64(Some(v)) => Some(Datum::date((*v / MILLIS_PER_DAY) 
as i32)),
+        _ => None,
+    }
+}
+
+/// Convert a DataFusion binary operator on a column reference and a literal to an Iceberg [`Predicate`].
+fn binary_op_to_predicate(reference: Reference, op: &Operator, datum: Datum) 
-> Predicate {
+    match op {
+        Operator::Eq => reference.equal_to(datum),
+        Operator::NotEq => reference.not_equal_to(datum),
+        Operator::Lt => reference.less_than(datum),
+        Operator::LtEq => reference.less_than_or_equal_to(datum),
+        Operator::Gt => reference.greater_than(datum),
+        Operator::GtEq => reference.greater_than_or_equal_to(datum),
+        _ => Predicate::AlwaysTrue,
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::collections::VecDeque;
+
+    use datafusion::arrow::datatypes::{DataType, Field, Schema};
+    use datafusion::common::tree_node::TreeNode;
+    use datafusion::common::DFSchema;
+    use datafusion::prelude::SessionContext;
+    use iceberg::expr::{Predicate, Reference};
+    use iceberg::spec::Datum;
+
+    use super::ExprToPredicateVisitor;
+
+    fn create_test_schema() -> DFSchema {
+        let arrow_schema = Schema::new(vec![
+            Field::new("foo", DataType::Int32, false),
+            Field::new("bar", DataType::Utf8, false),
+        ]);
+        DFSchema::try_from_qualified_schema("my_table", &arrow_schema).unwrap()
+    }
+
+    #[test]
+    fn test_predicate_conversion_with_single_condition() {
+        let sql = "foo > 1";
+        let df_schema = create_test_schema();
+        let expr = SessionContext::new()
+            .parse_sql_expr(sql, &df_schema)
+            .unwrap();
+        let mut visitor = ExprToPredicateVisitor::new();
+        expr.visit(&mut visitor).unwrap();
+        let predicate = visitor.get_predicate().unwrap();
+        assert_eq!(
+            predicate,
+            Reference::new("foo").greater_than(Datum::long(1))
+        );
+    }
+    #[test]
+    fn test_predicate_conversion_with_single_unsupported_condition() {
+        let sql = "foo is null";
+        let df_schema = create_test_schema();
+        let expr = SessionContext::new()
+            .parse_sql_expr(sql, &df_schema)
+            .unwrap();
+        let mut visitor = ExprToPredicateVisitor::new();
+        expr.visit(&mut visitor).unwrap();
+        let predicate = visitor.get_predicate();
+        assert_eq!(predicate, None);
+    }
+
+    #[test]
+    fn test_predicate_conversion_with_single_condition_rev() {
+        let sql = "1 < foo";
+        let df_schema = create_test_schema();
+        let expr = SessionContext::new()
+            .parse_sql_expr(sql, &df_schema)
+            .unwrap();
+        let mut visitor = ExprToPredicateVisitor::new();
+        expr.visit(&mut visitor).unwrap();
+        let predicate = visitor.get_predicate().unwrap();
+        assert_eq!(
+            predicate,
+            Reference::new("foo").greater_than(Datum::long(1))
+        );
+    }
+    #[test]
+    fn test_predicate_conversion_with_and_condition() {
+        let sql = "foo > 1 and bar = 'test'";
+        let df_schema = create_test_schema();
+        let expr = SessionContext::new()
+            .parse_sql_expr(sql, &df_schema)
+            .unwrap();
+        let mut visitor = ExprToPredicateVisitor::new();
+        expr.visit(&mut visitor).unwrap();
+        let predicate = visitor.get_predicate().unwrap();
+        let expected_predicate = Predicate::and(
+            Reference::new("foo").greater_than(Datum::long(1)),
+            Reference::new("bar").equal_to(Datum::string("test")),
+        );
+        assert_eq!(predicate, expected_predicate);
+    }
+
+    #[test]
+    fn test_predicate_conversion_with_and_condition_unsupported() {
+        let sql = "foo > 1 and bar is not null";
+        let df_schema = create_test_schema();
+        let expr = SessionContext::new()
+            .parse_sql_expr(sql, &df_schema)
+            .unwrap();
+        let mut visitor = ExprToPredicateVisitor::new();
+        expr.visit(&mut visitor).unwrap();
+        let predicate = visitor.get_predicate().unwrap();
+        let expected_predicate = 
Reference::new("foo").greater_than(Datum::long(1));
+        assert_eq!(predicate, expected_predicate);
+    }
+    #[test]
+    fn test_predicate_conversion_with_and_condition_both_unsupported() {
+        let sql = "foo in (1, 2, 3) and bar is not null";
+        let df_schema = create_test_schema();
+        let expr = SessionContext::new()
+            .parse_sql_expr(sql, &df_schema)
+            .unwrap();
+        let mut visitor = ExprToPredicateVisitor::new();
+        expr.visit(&mut visitor).unwrap();
+        let predicate = visitor.get_predicate();
+        let expected_predicate = None;
+        assert_eq!(predicate, expected_predicate);
+    }
+
+    #[test]
+    fn test_predicate_conversion_with_or_condition_unsupported() {
+        let sql = "foo > 1 or bar is not null";
+        let df_schema = create_test_schema();
+        let expr = SessionContext::new()
+            .parse_sql_expr(sql, &df_schema)
+            .unwrap();
+        let mut visitor = ExprToPredicateVisitor::new();
+        expr.visit(&mut visitor).unwrap();
+        let predicate = visitor.get_predicate();
+        let expected_predicate = None;
+        assert_eq!(predicate, expected_predicate);
+    }
+
+    #[test]
+    fn test_predicate_conversion_with_complex_binary_expr() {
+        let sql = "(foo > 1 and bar = 'test') or foo < 0 ";
+        let df_schema = create_test_schema();
+        let expr = SessionContext::new()
+            .parse_sql_expr(sql, &df_schema)
+            .unwrap();
+        let mut visitor = ExprToPredicateVisitor::new();
+        expr.visit(&mut visitor).unwrap();
+        let predicate = visitor.get_predicate().unwrap();
+        let inner_predicate = Predicate::and(
+            Reference::new("foo").greater_than(Datum::long(1)),
+            Reference::new("bar").equal_to(Datum::string("test")),
+        );
+        let expected_predicate = Predicate::or(
+            inner_predicate,
+            Reference::new("foo").less_than(Datum::long(0)),
+        );
+        assert_eq!(predicate, expected_predicate);
+    }
+
+    #[test]
+    fn test_predicate_conversion_with_complex_binary_expr_unsupported() {
+        let sql = "(foo > 1 or bar in ('test', 'test2')) and foo < 0 ";
+        let df_schema = create_test_schema();
+        let expr = SessionContext::new()
+            .parse_sql_expr(sql, &df_schema)
+            .unwrap();
+        let mut visitor = ExprToPredicateVisitor::new();
+        expr.visit(&mut visitor).unwrap();
+        let predicate = visitor.get_predicate().unwrap();
+        let expected_predicate = 
Reference::new("foo").less_than(Datum::long(0));
+        assert_eq!(predicate, expected_predicate);
+    }
+
+    #[test]
+    // test the get result method
+    fn test_get_result_multiple() {
+        let predicates = vec![
+            Some(Reference::new("foo").greater_than(Datum::long(1))),
+            None,
+            Some(Reference::new("bar").equal_to(Datum::string("test"))),
+        ];
+        let stack = VecDeque::from(predicates);
+        let visitor = ExprToPredicateVisitor { stack };
+        assert_eq!(
+            visitor.get_predicate(),
+            Some(Predicate::and(
+                Reference::new("foo").greater_than(Datum::long(1)),
+                Reference::new("bar").equal_to(Datum::string("test")),
+            ))
+        );
+    }
+
+    #[test]
+    fn test_get_result_single() {
+        let predicates = 
vec![Some(Reference::new("foo").greater_than(Datum::long(1)))];
+        let stack = VecDeque::from(predicates);
+        let visitor = ExprToPredicateVisitor { stack };
+        assert_eq!(
+            visitor.get_predicate(),
+            Some(Reference::new("foo").greater_than(Datum::long(1)))
+        );
+    }
+}
diff --git a/crates/integrations/datafusion/src/physical_plan/mod.rs 
b/crates/integrations/datafusion/src/physical_plan/mod.rs
index 5ae586a0..2fab109d 100644
--- a/crates/integrations/datafusion/src/physical_plan/mod.rs
+++ b/crates/integrations/datafusion/src/physical_plan/mod.rs
@@ -15,4 +15,5 @@
 // specific language governing permissions and limitations
 // under the License.
 
+pub(crate) mod expr_to_predicate;
 pub(crate) mod scan;
diff --git a/crates/integrations/datafusion/src/physical_plan/scan.rs 
b/crates/integrations/datafusion/src/physical_plan/scan.rs
index 576acea6..c53ce76d 100644
--- a/crates/integrations/datafusion/src/physical_plan/scan.rs
+++ b/crates/integrations/datafusion/src/physical_plan/scan.rs
@@ -22,6 +22,7 @@ use std::vec;
 
 use datafusion::arrow::array::RecordBatch;
 use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef;
+use datafusion::common::tree_node::TreeNode;
 use datafusion::error::Result as DFResult;
 use datafusion::execution::{SendableRecordBatchStream, TaskContext};
 use datafusion::physical_expr::EquivalenceProperties;
@@ -29,9 +30,12 @@ use 
datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use datafusion::physical_plan::{
     DisplayAs, ExecutionMode, ExecutionPlan, Partitioning, PlanProperties,
 };
+use datafusion::prelude::Expr;
 use futures::{Stream, TryStreamExt};
+use iceberg::expr::Predicate;
 use iceberg::table::Table;
 
+use crate::physical_plan::expr_to_predicate::ExprToPredicateVisitor;
 use crate::to_datafusion_error;
 
 /// Manages the scanning process of an Iceberg [`Table`], encapsulating the
@@ -47,6 +51,8 @@ pub(crate) struct IcebergTableScan {
     plan_properties: PlanProperties,
     /// Projection column names, None means all columns
     projection: Option<Vec<String>>,
+    /// Filters to apply to the table scan
+    predicates: Option<Predicate>,
 }
 
 impl IcebergTableScan {
@@ -55,15 +61,18 @@ impl IcebergTableScan {
         table: Table,
         schema: ArrowSchemaRef,
         projection: Option<&Vec<usize>>,
+        filters: &[Expr],
     ) -> Self {
         let plan_properties = Self::compute_properties(schema.clone());
         let projection = get_column_names(schema.clone(), projection);
+        let predicates = convert_filters_to_predicate(filters);
 
         Self {
             table,
             schema,
             plan_properties,
             projection,
+            predicates,
         }
     }
 
@@ -109,7 +118,11 @@ impl ExecutionPlan for IcebergTableScan {
         _partition: usize,
         _context: Arc<TaskContext>,
     ) -> DFResult<SendableRecordBatchStream> {
-        let fut = get_batch_stream(self.table.clone(), 
self.projection.clone());
+        let fut = get_batch_stream(
+            self.table.clone(),
+            self.projection.clone(),
+            self.predicates.clone(),
+        );
         let stream = futures::stream::once(fut).try_flatten();
 
         Ok(Box::pin(RecordBatchStreamAdapter::new(
@@ -143,11 +156,15 @@ impl DisplayAs for IcebergTableScan {
 async fn get_batch_stream(
     table: Table,
     column_names: Option<Vec<String>>,
+    predicates: Option<Predicate>,
 ) -> DFResult<Pin<Box<dyn Stream<Item = DFResult<RecordBatch>> + Send>>> {
-    let scan_builder = match column_names {
+    let mut scan_builder = match column_names {
         Some(column_names) => table.scan().select(column_names),
         None => table.scan().select_all(),
     };
+    if let Some(pred) = predicates {
+        scan_builder = scan_builder.with_filter(pred);
+    }
     let table_scan = scan_builder.build().map_err(to_datafusion_error)?;
 
     let stream = table_scan
@@ -155,10 +172,25 @@ async fn get_batch_stream(
         .await
         .map_err(to_datafusion_error)?
         .map_err(to_datafusion_error);
-
     Ok(Box::pin(stream))
 }
 
+/// Converts DataFusion filters ([`Expr`]) to an iceberg [`Predicate`].
+/// If none of the filters could be converted, return `None` which adds no 
predicates to the scan operation.
+/// If the conversion was successful, return the converted predicates combined 
with an AND operator.
+fn convert_filters_to_predicate(filters: &[Expr]) -> Option<Predicate> {
+    filters
+        .iter()
+        .filter_map(|expr| {
+            let mut visitor = ExprToPredicateVisitor::new();
+            if expr.visit(&mut visitor).is_ok() {
+                visitor.get_predicate()
+            } else {
+                None
+            }
+        })
+        .reduce(Predicate::and)
+}
 fn get_column_names(
     schema: ArrowSchemaRef,
     projection: Option<&Vec<usize>>,
diff --git a/crates/integrations/datafusion/src/table.rs 
b/crates/integrations/datafusion/src/table.rs
index 8d70d948..016c6c00 100644
--- a/crates/integrations/datafusion/src/table.rs
+++ b/crates/integrations/datafusion/src/table.rs
@@ -23,7 +23,7 @@ use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef;
 use datafusion::catalog::Session;
 use datafusion::datasource::{TableProvider, TableType};
 use datafusion::error::Result as DFResult;
-use datafusion::logical_expr::Expr;
+use datafusion::logical_expr::{BinaryExpr, Expr, TableProviderFilterPushDown};
 use datafusion::physical_plan::ExecutionPlan;
 use iceberg::arrow::schema_to_arrow_schema;
 use iceberg::table::Table;
@@ -76,13 +76,30 @@ impl TableProvider for IcebergTableProvider {
         &self,
         _state: &dyn Session,
         projection: Option<&Vec<usize>>,
-        _filters: &[Expr],
+        filters: &[Expr],
         _limit: Option<usize>,
     ) -> DFResult<Arc<dyn ExecutionPlan>> {
         Ok(Arc::new(IcebergTableScan::new(
             self.table.clone(),
             self.schema.clone(),
             projection,
+            filters,
         )))
     }
+
+    fn supports_filters_pushdown(
+        &self,
+        filters: &[&Expr],
+    ) -> std::result::Result<Vec<TableProviderFilterPushDown>, 
datafusion::error::DataFusionError>
+    {
+        let filter_support = filters
+            .iter()
+            .map(|e| match e {
+                Expr::BinaryExpr(BinaryExpr { .. }) => 
TableProviderFilterPushDown::Inexact,
+                _ => TableProviderFilterPushDown::Unsupported,
+            })
+            .collect::<Vec<TableProviderFilterPushDown>>();
+
+        Ok(filter_support)
+    }
 }

Reply via email to