This is an automated email from the ASF dual-hosted git repository.
xuanwo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-rust.git
The following commit(s) were added to refs/heads/main by this push:
new 1533c437 feat (datafusion integration): convert datafusion expr
filters to Iceberg Predicate (#588)
1533c437 is described below
commit 1533c4376e751c82097134857c551bd78b8f7b64
Author: Alon Agmon <[email protected]>
AuthorDate: Mon Sep 23 13:39:46 2024 +0300
feat (datafusion integration): convert datafusion expr filters to Iceberg
Predicate (#588)
* adding main function and tests
* adding tests, removing integration test for now
* fixing typos and lints
* fixing typing issue
* - added support in schema to convert Date32 to the correct arrow type
- refactored scan to use the new predicate converter as a visitor and separated
it into a new mod
- added support for simple predicates with column cast expressions
- added testing, mostly around date functions
* fixing format and lic
* reducing number of tests (17 -> 7)
* fix formats
* fix naming
* refactoring to use TreeNodeVisitor
* fixing fmt
* small refactor
* adding swapped op and fixing CR comments
---------
Co-authored-by: Alon Agmon <[email protected]>
---
crates/iceberg/src/arrow/schema.rs | 7 +-
.../src/physical_plan/expr_to_predicate.rs | 335 +++++++++++++++++++++
.../datafusion/src/physical_plan/mod.rs | 1 +
.../datafusion/src/physical_plan/scan.rs | 38 ++-
crates/integrations/datafusion/src/table.rs | 21 +-
5 files changed, 395 insertions(+), 7 deletions(-)
diff --git a/crates/iceberg/src/arrow/schema.rs
b/crates/iceberg/src/arrow/schema.rs
index 6c97621c..a32c10a2 100644
--- a/crates/iceberg/src/arrow/schema.rs
+++ b/crates/iceberg/src/arrow/schema.rs
@@ -24,8 +24,8 @@ use arrow_array::types::{
validate_decimal_precision_and_scale, Decimal128Type,
TimestampMicrosecondType,
};
use arrow_array::{
- BooleanArray, Datum as ArrowDatum, Float32Array, Float64Array, Int32Array,
Int64Array,
- PrimitiveArray, Scalar, StringArray, TimestampMicrosecondArray,
+ BooleanArray, Date32Array, Datum as ArrowDatum, Float32Array,
Float64Array, Int32Array,
+ Int64Array, PrimitiveArray, Scalar, StringArray, TimestampMicrosecondArray,
};
use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema, TimeUnit};
use bitvec::macros::internal::funty::Fundamental;
@@ -646,6 +646,9 @@ pub(crate) fn get_arrow_datum(datum: &Datum) ->
Result<Box<dyn ArrowDatum + Send
(PrimitiveType::String, PrimitiveLiteral::String(value)) => {
Ok(Box::new(StringArray::new_scalar(value.as_str())))
}
+ (PrimitiveType::Date, PrimitiveLiteral::Int(value)) => {
+ Ok(Box::new(Date32Array::new_scalar(*value)))
+ }
(PrimitiveType::Timestamp, PrimitiveLiteral::Long(value)) => {
Ok(Box::new(TimestampMicrosecondArray::new_scalar(*value)))
}
diff --git
a/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs
b/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs
new file mode 100644
index 00000000..110e4f7e
--- /dev/null
+++ b/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs
@@ -0,0 +1,335 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::collections::VecDeque;
+
+use datafusion::common::tree_node::{TreeNodeRecursion, TreeNodeVisitor};
+use datafusion::common::Column;
+use datafusion::error::DataFusionError;
+use datafusion::logical_expr::{Expr, Operator};
+use datafusion::scalar::ScalarValue;
+use iceberg::expr::{Predicate, Reference};
+use iceberg::spec::Datum;
+
+/// Visitor that walks a DataFusion [`Expr`] tree and collects the Iceberg
+/// [`Predicate`]s it was able to derive from it on an internal stack.
+pub struct ExprToPredicateVisitor {
+    /// Conversion results gathered so far; `None` marks an expression that
+    /// could not be converted.
+    stack: VecDeque<Option<Predicate>>,
+}
+
+impl ExprToPredicateVisitor {
+    /// Build a visitor whose result stack is empty.
+    pub fn new() -> Self {
+        let stack = VecDeque::new();
+        Self { stack }
+    }
+
+    /// AND together every converted predicate currently on the stack.
+    ///
+    /// Yields `None` when the stack contains no converted predicate at all.
+    pub fn get_predicate(&self) -> Option<Predicate> {
+        self.stack.iter().flatten().cloned().reduce(Predicate::and)
+    }
+
+    /// Turn a `column <op> literal` comparison into an iceberg predicate.
+    ///
+    /// Yields `None` when the literal cannot be expressed as a [`Datum`].
+    fn convert_column_expr(
+        &self,
+        col: &Column,
+        op: &Operator,
+        lit: &ScalarValue,
+    ) -> Option<Predicate> {
+        scalar_value_to_datum(lit)
+            .map(|datum| binary_op_to_predicate(Reference::new(col.name.clone()), op, datum))
+    }
+
+    /// Combine the converted operands of a logical expression.
+    ///
+    /// `valid_preds` holds only the operands that were converted successfully:
+    /// - `AND` keeps whatever is available, a single operand or both of them;
+    /// - `OR` produces a predicate only when both operands are available;
+    /// - every other combination yields `None`.
+    fn convert_compound_expr(&self, valid_preds: &[Predicate], op: &Operator) -> Option<Predicate> {
+        match (op, valid_preds) {
+            (Operator::And, [only]) => Some(only.clone()),
+            (Operator::And, [lhs, rhs]) => Some(Predicate::and(lhs.clone(), rhs.clone())),
+            (Operator::Or, [lhs, rhs]) => Some(Predicate::or(lhs.clone(), rhs.clone())),
+            _ => None,
+        }
+    }
+}
+
+// Implement TreeNodeVisitor for ExprToPredicateVisitor.
+//
+// Stack discipline: every node that is visited pushes exactly one entry in
+// `f_up`, and a binary expression first pops the two entries that belong to its
+// operands. This guarantees that a logical operator only ever combines the
+// results of its own children. Without it, an expression such as
+// `a = 1 AND NOT (b = 2)` would be converted to `a = 1 AND b = 2`, which prunes
+// rows that actually match the filter.
+impl<'n> TreeNodeVisitor<'n> for ExprToPredicateVisitor {
+    type Node = Expr;
+
+    /// Descend only into binary expressions.
+    ///
+    /// Any other node (`NOT`, `IS NULL`, `IN`, function calls, ...) cannot be
+    /// converted, so its subtree is pruned with [`TreeNodeRecursion::Jump`]:
+    /// the children are skipped while `f_up` still runs for the node itself.
+    /// Pruning keeps nested comparisons from leaking entries onto the stack.
+    fn f_down(&mut self, node: &'n Expr) -> Result<TreeNodeRecursion, DataFusionError> {
+        if matches!(node, Expr::BinaryExpr(_)) {
+            Ok(TreeNodeRecursion::Continue)
+        } else {
+            Ok(TreeNodeRecursion::Jump)
+        }
+    }
+
+    /// Convert the node on the way back up and push the outcome on the stack.
+    ///
+    /// `None` is pushed for everything that cannot be converted, so that the
+    /// parent always finds one entry per operand.
+    fn f_up(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion, DataFusionError> {
+        let converted = if let Expr::BinaryExpr(binary) = expr {
+            // Both operands were visited before this node; take their entries
+            // off the stack (right operand first, it was pushed last).
+            let right_pred = self.stack.pop_back().flatten();
+            let left_pred = self.stack.pop_back().flatten();
+            match (&*binary.left, &binary.op, &*binary.right) {
+                // process simple binary expressions, e.g. col > 1
+                (Expr::Column(col), op, Expr::Literal(lit)) => {
+                    self.convert_column_expr(col, op, lit)
+                }
+                // process reversed binary expressions, e.g. 1 < col
+                (Expr::Literal(lit), op, Expr::Column(col)) => op
+                    .swap()
+                    .and_then(|swapped_op| self.convert_column_expr(col, &swapped_op, lit)),
+                // process compound expressions (involving logical operators,
+                // e.g., AND or OR and children)
+                (_left, op, _right) if op.is_logic_operator() => {
+                    let children: Vec<_> =
+                        [left_pred, right_pred].into_iter().flatten().collect();
+                    self.convert_compound_expr(&children, op)
+                }
+                // any other binary expression (e.g. arithmetic) is unsupported
+                _ => None,
+            }
+        } else {
+            // Pruned in `f_down`: nothing to convert.
+            None
+        };
+        self.stack.push_back(converted);
+        Ok(TreeNodeRecursion::Continue)
+    }
+}
+
+/// Number of milliseconds in one day, used to reduce `Date64` values to days.
+const MILLIS_PER_DAY: i64 = 24 * 60 * 60 * 1000;
+
+/// Convert a scalar value to an iceberg datum.
+///
+/// Returns `None` for null scalars, for scalar types that have no mapping
+/// here, and for `Date64` values whose day number does not fit in an `i32`.
+fn scalar_value_to_datum(value: &ScalarValue) -> Option<Datum> {
+    match value {
+        ScalarValue::Int8(Some(v)) => Some(Datum::int(*v as i32)),
+        ScalarValue::Int16(Some(v)) => Some(Datum::int(*v as i32)),
+        ScalarValue::Int32(Some(v)) => Some(Datum::int(*v)),
+        ScalarValue::Int64(Some(v)) => Some(Datum::long(*v)),
+        ScalarValue::Float32(Some(v)) => Some(Datum::double(*v as f64)),
+        ScalarValue::Float64(Some(v)) => Some(Datum::double(*v)),
+        ScalarValue::Utf8(Some(v)) => Some(Datum::string(v.clone())),
+        ScalarValue::LargeUtf8(Some(v)) => Some(Datum::string(v.clone())),
+        ScalarValue::Date32(Some(v)) => Some(Datum::date(*v)),
+        // `Date64` holds milliseconds and is reduced to a day count here.
+        // Floor division keeps instants before the epoch on the correct day
+        // (plain `/` rounds toward zero), and the checked narrowing avoids a
+        // silently wrapped day number for out-of-range values.
+        ScalarValue::Date64(Some(v)) => i32::try_from(v.div_euclid(MILLIS_PER_DAY))
+            .ok()
+            .map(Datum::date),
+        _ => None,
+    }
+}
+
+/// Build the iceberg [`Predicate`] for `reference <op> datum`.
+///
+/// Only the six comparison operators are translated; any other operator
+/// yields [`Predicate::AlwaysTrue`].
+fn binary_op_to_predicate(reference: Reference, op: &Operator, datum: Datum) -> Predicate {
+    match op {
+        // ordering comparisons
+        Operator::Lt => reference.less_than(datum),
+        Operator::LtEq => reference.less_than_or_equal_to(datum),
+        Operator::Gt => reference.greater_than(datum),
+        Operator::GtEq => reference.greater_than_or_equal_to(datum),
+        // equality comparisons
+        Operator::Eq => reference.equal_to(datum),
+        Operator::NotEq => reference.not_equal_to(datum),
+        // everything else is not translated
+        _ => Predicate::AlwaysTrue,
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::collections::VecDeque;
+
+    use datafusion::arrow::datatypes::{DataType, Field, Schema};
+    use datafusion::common::tree_node::TreeNode;
+    use datafusion::common::DFSchema;
+    use datafusion::prelude::SessionContext;
+    use iceberg::expr::{Predicate, Reference};
+    use iceberg::spec::Datum;
+
+    use super::ExprToPredicateVisitor;
+
+    /// Schema used by the SQL based tests: `foo` (Int32) and `bar` (Utf8).
+    fn create_test_schema() -> DFSchema {
+        let fields = vec![
+            Field::new("foo", DataType::Int32, false),
+            Field::new("bar", DataType::Utf8, false),
+        ];
+        DFSchema::try_from_qualified_schema("my_table", &Schema::new(fields)).unwrap()
+    }
+
+    /// Parse `sql` against the test schema, run the visitor over the parsed
+    /// expression and return whatever predicate it produced.
+    fn predicate_for(sql: &str) -> Option<Predicate> {
+        let df_schema = create_test_schema();
+        let expr = SessionContext::new()
+            .parse_sql_expr(sql, &df_schema)
+            .unwrap();
+        let mut visitor = ExprToPredicateVisitor::new();
+        expr.visit(&mut visitor).unwrap();
+        visitor.get_predicate()
+    }
+
+    /// `foo > 1`
+    fn foo_gt_1() -> Predicate {
+        Reference::new("foo").greater_than(Datum::long(1))
+    }
+
+    /// `foo < 0`
+    fn foo_lt_0() -> Predicate {
+        Reference::new("foo").less_than(Datum::long(0))
+    }
+
+    /// `bar = 'test'`
+    fn bar_eq_test() -> Predicate {
+        Reference::new("bar").equal_to(Datum::string("test"))
+    }
+
+    #[test]
+    fn test_predicate_conversion_with_single_condition() {
+        let predicate = predicate_for("foo > 1").unwrap();
+        assert_eq!(predicate, foo_gt_1());
+    }
+
+    #[test]
+    fn test_predicate_conversion_with_single_unsupported_condition() {
+        let predicate = predicate_for("foo is null");
+        assert_eq!(predicate, None);
+    }
+
+    #[test]
+    fn test_predicate_conversion_with_single_condition_rev() {
+        let predicate = predicate_for("1 < foo").unwrap();
+        assert_eq!(predicate, foo_gt_1());
+    }
+
+    #[test]
+    fn test_predicate_conversion_with_and_condition() {
+        let predicate = predicate_for("foo > 1 and bar = 'test'").unwrap();
+        let expected_predicate = Predicate::and(foo_gt_1(), bar_eq_test());
+        assert_eq!(predicate, expected_predicate);
+    }
+
+    #[test]
+    fn test_predicate_conversion_with_and_condition_unsupported() {
+        let predicate = predicate_for("foo > 1 and bar is not null").unwrap();
+        assert_eq!(predicate, foo_gt_1());
+    }
+
+    #[test]
+    fn test_predicate_conversion_with_and_condition_both_unsupported() {
+        let predicate = predicate_for("foo in (1, 2, 3) and bar is not null");
+        assert_eq!(predicate, None);
+    }
+
+    #[test]
+    fn test_predicate_conversion_with_or_condition_unsupported() {
+        let predicate = predicate_for("foo > 1 or bar is not null");
+        assert_eq!(predicate, None);
+    }
+
+    #[test]
+    fn test_predicate_conversion_with_complex_binary_expr() {
+        let predicate = predicate_for("(foo > 1 and bar = 'test') or foo < 0 ").unwrap();
+        let inner_predicate = Predicate::and(foo_gt_1(), bar_eq_test());
+        let expected_predicate = Predicate::or(inner_predicate, foo_lt_0());
+        assert_eq!(predicate, expected_predicate);
+    }
+
+    #[test]
+    fn test_predicate_conversion_with_complex_binary_expr_unsupported() {
+        let predicate =
+            predicate_for("(foo > 1 or bar in ('test', 'test2')) and foo < 0 ").unwrap();
+        assert_eq!(predicate, foo_lt_0());
+    }
+
+    // `get_predicate` skips the `None` entries and ANDs the remaining ones.
+    #[test]
+    fn test_get_result_multiple() {
+        let stack = VecDeque::from(vec![Some(foo_gt_1()), None, Some(bar_eq_test())]);
+        let visitor = ExprToPredicateVisitor { stack };
+        assert_eq!(
+            visitor.get_predicate(),
+            Some(Predicate::and(foo_gt_1(), bar_eq_test()))
+        );
+    }
+
+    // A single entry is returned unchanged.
+    #[test]
+    fn test_get_result_single() {
+        let stack = VecDeque::from(vec![Some(foo_gt_1())]);
+        let visitor = ExprToPredicateVisitor { stack };
+        assert_eq!(visitor.get_predicate(), Some(foo_gt_1()));
+    }
+}
diff --git a/crates/integrations/datafusion/src/physical_plan/mod.rs
b/crates/integrations/datafusion/src/physical_plan/mod.rs
index 5ae586a0..2fab109d 100644
--- a/crates/integrations/datafusion/src/physical_plan/mod.rs
+++ b/crates/integrations/datafusion/src/physical_plan/mod.rs
@@ -15,4 +15,5 @@
// specific language governing permissions and limitations
// under the License.
+pub(crate) mod expr_to_predicate;
pub(crate) mod scan;
diff --git a/crates/integrations/datafusion/src/physical_plan/scan.rs
b/crates/integrations/datafusion/src/physical_plan/scan.rs
index 576acea6..c53ce76d 100644
--- a/crates/integrations/datafusion/src/physical_plan/scan.rs
+++ b/crates/integrations/datafusion/src/physical_plan/scan.rs
@@ -22,6 +22,7 @@ use std::vec;
use datafusion::arrow::array::RecordBatch;
use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef;
+use datafusion::common::tree_node::TreeNode;
use datafusion::error::Result as DFResult;
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::physical_expr::EquivalenceProperties;
@@ -29,9 +30,12 @@ use
datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{
DisplayAs, ExecutionMode, ExecutionPlan, Partitioning, PlanProperties,
};
+use datafusion::prelude::Expr;
use futures::{Stream, TryStreamExt};
+use iceberg::expr::Predicate;
use iceberg::table::Table;
+use crate::physical_plan::expr_to_predicate::ExprToPredicateVisitor;
use crate::to_datafusion_error;
/// Manages the scanning process of an Iceberg [`Table`], encapsulating the
@@ -47,6 +51,8 @@ pub(crate) struct IcebergTableScan {
plan_properties: PlanProperties,
/// Projection column names, None means all columns
projection: Option<Vec<String>>,
+ /// Filters to apply to the table scan
+ predicates: Option<Predicate>,
}
impl IcebergTableScan {
@@ -55,15 +61,18 @@ impl IcebergTableScan {
table: Table,
schema: ArrowSchemaRef,
projection: Option<&Vec<usize>>,
+ filters: &[Expr],
) -> Self {
let plan_properties = Self::compute_properties(schema.clone());
let projection = get_column_names(schema.clone(), projection);
+ let predicates = convert_filters_to_predicate(filters);
Self {
table,
schema,
plan_properties,
projection,
+ predicates,
}
}
@@ -109,7 +118,11 @@ impl ExecutionPlan for IcebergTableScan {
_partition: usize,
_context: Arc<TaskContext>,
) -> DFResult<SendableRecordBatchStream> {
- let fut = get_batch_stream(self.table.clone(),
self.projection.clone());
+ let fut = get_batch_stream(
+ self.table.clone(),
+ self.projection.clone(),
+ self.predicates.clone(),
+ );
let stream = futures::stream::once(fut).try_flatten();
Ok(Box::pin(RecordBatchStreamAdapter::new(
@@ -143,11 +156,15 @@ impl DisplayAs for IcebergTableScan {
async fn get_batch_stream(
table: Table,
column_names: Option<Vec<String>>,
+ predicates: Option<Predicate>,
) -> DFResult<Pin<Box<dyn Stream<Item = DFResult<RecordBatch>> + Send>>> {
- let scan_builder = match column_names {
+ let mut scan_builder = match column_names {
Some(column_names) => table.scan().select(column_names),
None => table.scan().select_all(),
};
+ if let Some(pred) = predicates {
+ scan_builder = scan_builder.with_filter(pred);
+ }
let table_scan = scan_builder.build().map_err(to_datafusion_error)?;
let stream = table_scan
@@ -155,10 +172,25 @@ async fn get_batch_stream(
.await
.map_err(to_datafusion_error)?
.map_err(to_datafusion_error);
-
Ok(Box::pin(stream))
}
+/// Translate DataFusion filter expressions ([`Expr`]) into one iceberg [`Predicate`].
+///
+/// Each filter is converted on its own. Filters whose traversal fails or that
+/// produce no predicate are skipped, and the remaining predicates are joined
+/// with AND. `None` is returned when nothing could be converted, in which case
+/// no predicate is added to the scan operation.
+fn convert_filters_to_predicate(filters: &[Expr]) -> Option<Predicate> {
+    let mut combined: Option<Predicate> = None;
+    for expr in filters {
+        let mut visitor = ExprToPredicateVisitor::new();
+        if expr.visit(&mut visitor).is_err() {
+            continue;
+        }
+        if let Some(pred) = visitor.get_predicate() {
+            combined = Some(match combined {
+                Some(acc) => Predicate::and(acc, pred),
+                None => pred,
+            });
+        }
+    }
+    combined
+}
fn get_column_names(
schema: ArrowSchemaRef,
projection: Option<&Vec<usize>>,
diff --git a/crates/integrations/datafusion/src/table.rs
b/crates/integrations/datafusion/src/table.rs
index 8d70d948..016c6c00 100644
--- a/crates/integrations/datafusion/src/table.rs
+++ b/crates/integrations/datafusion/src/table.rs
@@ -23,7 +23,7 @@ use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef;
use datafusion::catalog::Session;
use datafusion::datasource::{TableProvider, TableType};
use datafusion::error::Result as DFResult;
-use datafusion::logical_expr::Expr;
+use datafusion::logical_expr::{BinaryExpr, Expr, TableProviderFilterPushDown};
use datafusion::physical_plan::ExecutionPlan;
use iceberg::arrow::schema_to_arrow_schema;
use iceberg::table::Table;
@@ -76,13 +76,30 @@ impl TableProvider for IcebergTableProvider {
&self,
_state: &dyn Session,
projection: Option<&Vec<usize>>,
- _filters: &[Expr],
+ filters: &[Expr],
_limit: Option<usize>,
) -> DFResult<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(IcebergTableScan::new(
self.table.clone(),
self.schema.clone(),
projection,
+ filters,
)))
}
+
+    /// Report, per filter, whether it can be handed to the table scan.
+    ///
+    /// Binary expressions are reported as `Inexact`; every other kind of
+    /// expression is reported as `Unsupported`.
+    fn supports_filters_pushdown(
+        &self,
+        filters: &[&Expr],
+    ) -> std::result::Result<Vec<TableProviderFilterPushDown>, datafusion::error::DataFusionError>
+    {
+        let mut filter_support = Vec::with_capacity(filters.len());
+        for filter in filters {
+            let support = if let Expr::BinaryExpr(BinaryExpr { .. }) = filter {
+                TableProviderFilterPushDown::Inexact
+            } else {
+                TableProviderFilterPushDown::Unsupported
+            };
+            filter_support.push(support);
+        }
+
+        Ok(filter_support)
+    }
}