This is an automated email from the ASF dual-hosted git repository.

xushiyan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/hudi-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new d04b9ab  feat: support row filters for `FileGroupReader` (#237)
d04b9ab is described below

commit d04b9ab44a424477ced07406be0e74a9baa3139a
Author: Shiyan Xu <[email protected]>
AuthorDate: Mon Jan 6 23:51:24 2025 -0600

    feat: support row filters for `FileGroupReader` (#237)
---
 crates/core/src/error.rs             |   4 +-
 crates/core/src/expr/filter.rs       | 152 ++++++++++++++++++++++++++++++-
 crates/core/src/file_group/reader.rs | 167 ++++++++++++++++++++++++++++++++---
 crates/core/src/table/partition.rs   |  14 +--
 4 files changed, 310 insertions(+), 27 deletions(-)

diff --git a/crates/core/src/error.rs b/crates/core/src/error.rs
index ae7ef2d..55aa8e8 100644
--- a/crates/core/src/error.rs
+++ b/crates/core/src/error.rs
@@ -36,8 +36,8 @@ pub enum CoreError {
     #[error("File group error: {0}")]
     FileGroup(String),
 
-    #[error("{0}: {1:?}")]
-    ReadFileSliceError(String, StorageError),
+    #[error("{0}")]
+    ReadFileSliceError(String),
 
     #[error("{0}")]
     InvalidPartitionPath(String),
diff --git a/crates/core/src/expr/filter.rs b/crates/core/src/expr/filter.rs
index 6be42f9..685aec4 100644
--- a/crates/core/src/expr/filter.rs
+++ b/crates/core/src/expr/filter.rs
@@ -20,8 +20,9 @@
 use crate::error::CoreError;
 use crate::expr::ExprOperator;
 use crate::Result;
-use arrow_array::{ArrayRef, Scalar, StringArray};
+use arrow_array::{ArrayRef, BooleanArray, Datum, Scalar, StringArray};
 use arrow_cast::{cast_with_options, CastOptions};
+use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq};
 use arrow_schema::{DataType, Field, Schema};
 use std::str::FromStr;
 
@@ -171,4 +172,153 @@ impl SchemableFilter {
             })?,
         ))
     }
+
+    /// Evaluates this filter's operator with `value` on the left-hand side and
+    /// the filter's literal on the right-hand side.
+    ///
+    /// The result has one entry per element of `value`; if `value` is a scalar
+    /// [`Datum`] the result has a single entry. Comparing a null yields a null entry.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the underlying `arrow_ord::cmp` kernel fails, e.g. when
+    /// `value`'s data type cannot be compared with the filter's literal.
+    pub fn apply_comparison(&self, value: &dyn Datum) -> Result<BooleanArray> {
+        let literal = &self.value;
+        let result = match self.operator {
+            ExprOperator::Eq => eq(value, literal),
+            ExprOperator::Ne => neq(value, literal),
+            ExprOperator::Lt => lt(value, literal),
+            ExprOperator::Lte => lt_eq(value, literal),
+            ExprOperator::Gt => gt(value, literal),
+            ExprOperator::Gte => gt_eq(value, literal),
+        };
+        result.map_err(Into::into)
+    }
+
+    /// Misspelled alias of [`Self::apply_comparison`], retained so existing
+    /// callers keep compiling. Prefer `apply_comparison` in new code.
+    pub fn apply_comparsion(&self, value: &dyn Datum) -> Result<BooleanArray> {
+        self.apply_comparison(value)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow_array::{Int64Array, StringArray};
+    use arrow_schema::{DataType, Field, Schema};
+
+    fn create_test_schema() -> Schema {
+        Schema::new(vec![
+            Field::new("string_col", DataType::Utf8, false),
+            Field::new("int_col", DataType::Int64, false),
+        ])
+    }
+
+    #[test]
+    fn test_schemable_filter_try_from() -> Result<()> {
+        let schema = create_test_schema();
+
+        // Test string column filter creation
+        let string_filter = Filter {
+            field_name: "string_col".to_string(),
+            operator: ExprOperator::Eq,
+            field_value: "test_value".to_string(),
+        };
+
+        let schemable = SchemableFilter::try_from((string_filter, &schema))?;
+        assert_eq!(schemable.field.name(), "string_col");
+        assert_eq!(schemable.field.data_type(), &DataType::Utf8);
+        assert_eq!(schemable.operator, ExprOperator::Eq);
+
+        // Test integer column filter creation
+        let int_filter = Filter {
+            field_name: "int_col".to_string(),
+            operator: ExprOperator::Gt,
+            field_value: "42".to_string(),
+        };
+
+        let schemable = SchemableFilter::try_from((int_filter, &schema))?;
+        assert_eq!(schemable.field.name(), "int_col");
+        assert_eq!(schemable.field.data_type(), &DataType::Int64);
+        assert_eq!(schemable.operator, ExprOperator::Gt);
+
+        // Test error case - non-existent column
+        let invalid_filter = Filter {
+            field_name: "non_existent".to_string(),
+            operator: ExprOperator::Eq,
+            field_value: "value".to_string(),
+        };
+
+        assert!(SchemableFilter::try_from((invalid_filter, &schema)).is_err());
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_schemable_filter_cast_value() -> Result<()> {
+        // Test casting to string: a single-element scalar of the requested type
+        let string_value = SchemableFilter::cast_value(&["test"], &DataType::Utf8)?;
+        let (string_array, is_scalar) = string_value.get();
+        assert!(is_scalar);
+        assert_eq!(string_array.len(), 1);
+        assert_eq!(string_array.data_type(), &DataType::Utf8);
+
+        // Test casting to integer: the literal must actually become Int64
+        let int_value = SchemableFilter::cast_value(&["42"], &DataType::Int64)?;
+        let (int_array, is_scalar) = int_value.get();
+        assert!(is_scalar);
+        assert_eq!(int_array.len(), 1);
+        assert_eq!(int_array.data_type(), &DataType::Int64);
+
+        // Test invalid integer cast
+        let result = SchemableFilter::cast_value(&["not_a_number"], &DataType::Int64);
+        assert!(result.is_err());
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_schemable_filter_apply_comparison() -> Result<()> {
+        let schema = create_test_schema();
+
+        // Test string equality comparison
+        let eq_filter = Filter {
+            field_name: "string_col".to_string(),
+            operator: ExprOperator::Eq,
+            field_value: "test".to_string(),
+        };
+        let schemable = SchemableFilter::try_from((eq_filter, &schema))?;
+
+        let test_array = StringArray::from(vec!["test", "other", "test"]);
+        let result = schemable.apply_comparsion(&test_array)?;
+        assert_eq!(result, BooleanArray::from(vec![true, false, true]));
+
+        // Test integer greater than comparison
+        let gt_filter = Filter {
+            field_name: "int_col".to_string(),
+            operator: ExprOperator::Gt,
+            field_value: "50".to_string(),
+        };
+        let schemable = SchemableFilter::try_from((gt_filter, &schema))?;
+
+        let test_array = Int64Array::from(vec![40, 50, 60]);
+        let result = schemable.apply_comparsion(&test_array)?;
+        assert_eq!(result, BooleanArray::from(vec![false, false, true]));
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_schemable_filter_all_operators() -> Result<()> {
+        let schema = create_test_schema();
+        let test_array = Int64Array::from(vec![40, 50, 60]);
+
+        let test_cases = vec![
+            (ExprOperator::Eq, "50", vec![false, true, false]),
+            (ExprOperator::Ne, "50", vec![true, false, true]),
+            (ExprOperator::Lt, "50", vec![true, false, false]),
+            (ExprOperator::Lte, "50", vec![true, true, false]),
+            (ExprOperator::Gt, "50", vec![false, false, true]),
+            (ExprOperator::Gte, "50", vec![false, true, true]),
+        ];
+
+        for (operator, value, expected) in test_cases {
+            let filter = Filter {
+                field_name: "int_col".to_string(),
+                operator,
+                field_value: value.to_string(),
+            };
+
+            let schemable = SchemableFilter::try_from((filter, &schema))?;
+            let result = schemable.apply_comparsion(&test_array)?;
+            assert_eq!(
+                result,
+                BooleanArray::from(expected),
+                "Failed for operator {:?} with value {}",
+                operator,
+                value
+            );
+        }
+
+        Ok(())
+    }
 }
diff --git a/crates/core/src/file_group/reader.rs b/crates/core/src/file_group/reader.rs
index 165ed9c..a7884d0 100644
--- a/crates/core/src/file_group/reader.rs
+++ b/crates/core/src/file_group/reader.rs
@@ -20,22 +20,47 @@ use crate::config::table::HudiTableConfig;
 use crate::config::util::split_hudi_options_from_others;
 use crate::config::HudiConfigs;
 use crate::error::CoreError::ReadFileSliceError;
+use crate::expr::filter::{Filter, SchemableFilter};
 use crate::file_group::FileSlice;
 use crate::storage::Storage;
 use crate::Result;
-use arrow_array::RecordBatch;
+use arrow::compute::and;
+use arrow_array::{BooleanArray, RecordBatch};
+use arrow_schema::Schema;
 use futures::TryFutureExt;
 use std::sync::Arc;
 
+use arrow::compute::filter_record_batch;
+
 /// File group reader handles all read operations against a file group.
 #[derive(Clone, Debug)]
 pub struct FileGroupReader {
+    /// Storage backend the reader fetches file data from.
     storage: Arc<Storage>,
+    /// Row filters combined with logical AND. When non-empty, rows that fail
+    /// any filter are removed from batches read via
+    /// `read_file_slice_by_base_file_path`; when empty, rows are returned as read.
+    and_filters: Vec<SchemableFilter>,
 }
 
 impl FileGroupReader {
     pub fn new(storage: Arc<Storage>) -> Self {
-        Self { storage }
+        Self {
+            storage,
+            and_filters: Vec::new(),
+        }
+    }
+
+    /// Creates a reader that keeps only rows satisfying every filter in `and_filters`.
+    ///
+    /// Each [`Filter`] is resolved against `schema` into a [`SchemableFilter`],
+    /// preserving the order in which the filters were given.
+    ///
+    /// # Errors
+    ///
+    /// Returns the first error raised while resolving a filter, e.g. when a
+    /// filter names a field that `schema` does not contain.
+    pub fn new_with_filters(
+        storage: Arc<Storage>,
+        and_filters: &[Filter],
+        schema: &Schema,
+    ) -> Result<Self> {
+        let mut resolved = Vec::with_capacity(and_filters.len());
+        for filter in and_filters {
+            resolved.push(SchemableFilter::try_from((filter.clone(), schema))?);
+        }
+
+        Ok(Self {
+            storage,
+            and_filters: resolved,
+        })
     }
 
     pub fn new_with_options<I, K, V>(base_uri: &str, options: I) -> 
Result<Self>
@@ -53,22 +78,43 @@ impl FileGroupReader {
         let hudi_configs = Arc::new(HudiConfigs::new(hudi_opts));
 
         let storage = Storage::new(Arc::new(others), hudi_configs)?;
-        Ok(Self { storage })
+        Ok(Self {
+            storage,
+            and_filters: Vec::new(),
+        })
+    }
+
+    /// Builds the row-selection mask for `records` from `self.and_filters`.
+    ///
+    /// Starts from an all-`true` mask and ANDs in each filter's comparison result,
+    /// so a row is selected only if it passes every filter. With no filters the
+    /// mask selects every row.
+    ///
+    /// # Errors
+    ///
+    /// Returns `ReadFileSliceError` if a filtered column is missing from `records`,
+    /// and propagates errors from the comparison and `and` kernels.
+    fn create_boolean_array_mask(&self, records: &RecordBatch) -> Result<BooleanArray> {
+        let mut mask = BooleanArray::from(vec![true; records.num_rows()]);
+        for filter in &self.and_filters {
+            let col_name = filter.field.name().as_str();
+            let col_values = records
+                .column_by_name(col_name)
+                .ok_or_else(|| ReadFileSliceError(format!("Column {col_name} not found")))?;
+
+            // `and` is not Kleene-aware: a null comparison entry makes the mask
+            // entry null as well.
+            let comparison = filter.apply_comparsion(col_values)?;
+            mask = and(&mask, &comparison)?;
+        }
+        Ok(mask)
     }
 
     pub async fn read_file_slice_by_base_file_path(
         &self,
         relative_path: &str,
     ) -> Result<RecordBatch> {
-        self.storage
+        let records: RecordBatch = self
+            .storage
             .get_parquet_file_data(relative_path)
-            .map_err(|e| {
-                ReadFileSliceError(
-                    format!("Failed to read file slice at path '{}'", 
relative_path),
-                    e,
-                )
-            })
-            .await
+            .map_err(|e| ReadFileSliceError(format!("Failed to read path 
{relative_path}: {e:?}")))
+            .await?;
+
+        if self.and_filters.is_empty() {
+            return Ok(records);
+        }
+
+        let mask = self.create_boolean_array_mask(&records)?;
+        filter_record_batch(&records, &mask)
+            .map_err(|e| ReadFileSliceError(format!("Failed to filter records: 
{e:?}")))
     }
 
     pub async fn read_file_slice(&self, file_slice: &FileSlice) -> 
Result<RecordBatch> {
@@ -80,6 +126,12 @@ impl FileGroupReader {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::error::CoreError;
+    use crate::expr::filter::FilterField;
+    use arrow::array::{ArrayRef, Int64Array, StringArray};
+    use arrow::record_batch::RecordBatch;
+    use arrow_schema::{DataType, Field, Schema};
+    use std::sync::Arc;
     use url::Url;
 
     #[test]
@@ -90,6 +142,41 @@ mod tests {
         assert!(Arc::ptr_eq(&fg_reader.storage, &storage));
     }
 
+    /// Schema shared by the reader tests: a string commit time, a string name,
+    /// and an int64 age, all non-nullable.
+    fn create_test_schema() -> Schema {
+        let fields = vec![
+            Field::new("_hoodie_commit_time", DataType::Utf8, false),
+            Field::new("name", DataType::Utf8, false),
+            Field::new("age", DataType::Int64, false),
+        ];
+        Schema::new(fields)
+    }
+
+    // Construction is fully synchronous, so no async runtime is needed here.
+    #[test]
+    fn test_new_with_filters() -> Result<()> {
+        let base_url = Url::parse("file:///tmp/hudi_data").unwrap();
+        let storage = Storage::new_with_base_url(base_url)?;
+        let schema = create_test_schema();
+
+        // Test case 1: Empty filters
+        let reader = FileGroupReader::new_with_filters(storage.clone(), &[], &schema)?;
+        assert!(reader.and_filters.is_empty());
+
+        // Test case 2: Multiple filters, resolved in the order given
+        let filters = vec![
+            FilterField::new("_hoodie_commit_time").gt("0"),
+            FilterField::new("age").gte("18"),
+        ];
+        let reader = FileGroupReader::new_with_filters(storage.clone(), &filters, &schema)?;
+        assert_eq!(reader.and_filters.len(), 2);
+        let names: Vec<&str> = reader.and_filters.iter().map(|f| f.field.name().as_str()).collect();
+        assert_eq!(names, vec!["_hoodie_commit_time", "age"]);
+
+        // Test case 3: Invalid field name should error
+        let invalid_filters = vec![FilterField::new("non_existent_field").eq("value")];
+        assert!(FileGroupReader::new_with_filters(storage, &invalid_filters, &schema).is_err());
+
+        Ok(())
+    }
+
     #[test]
     fn test_new_with_options() -> Result<()> {
         let options = vec![("key1", "value1"), ("key2", "value2")];
@@ -111,6 +198,62 @@ mod tests {
         let result = reader
             .read_file_slice_by_base_file_path("non_existent_file")
             .await;
-        assert!(matches!(result.unwrap_err(), ReadFileSliceError(_, _)));
+        assert!(matches!(result.unwrap_err(), ReadFileSliceError(_)));
+    }
+
+    /// Builds a five-row batch matching `create_test_schema()`: commit times
+    /// "1".."5", names Alice..Eve, and ages 25..45 in steps of 5.
+    fn create_test_record_batch() -> Result<RecordBatch> {
+        let schema = Arc::new(create_test_schema());
+
+        let commit_times: ArrayRef = Arc::new(StringArray::from(vec!["1", "2", "3", "4", "5"]));
+        let names: ArrayRef = Arc::new(StringArray::from(vec![
+            "Alice", "Bob", "Charlie", "David", "Eve",
+        ]));
+        let ages: ArrayRef = Arc::new(Int64Array::from(vec![25, 30, 35, 40, 45]));
+
+        RecordBatch::try_new(schema, vec![commit_times, names, ages])
+            .map_err(CoreError::ArrowError)
+    }
+
+    #[test]
+    fn test_create_boolean_array_mask() -> Result<()> {
+        let storage =
+            Storage::new_with_base_url(Url::parse("file:///non-existent-path/table").unwrap())?;
+        let schema = create_test_schema();
+        let records = create_test_record_batch()?;
+
+        // Test case 1: No filters
+        let reader = FileGroupReader::new_with_filters(storage.clone(), &[], &schema)?;
+        let mask = reader.create_boolean_array_mask(&records)?;
+        assert_eq!(mask, BooleanArray::from(vec![true; 5]));
+
+        // Test case 2: Single filter on commit time
+        let filters = vec![FilterField::new("_hoodie_commit_time").gt("2")];
+        let reader = FileGroupReader::new_with_filters(storage.clone(), &filters, &schema)?;
+        let mask = reader.create_boolean_array_mask(&records)?;
+        assert_eq!(
+            mask,
+            BooleanArray::from(vec![false, false, true, true, true]),
+            "Expected only records with commit_time > '2'"
+        );
+
+        // Test case 3: Multiple AND filters
+        let filters = vec![
+            FilterField::new("_hoodie_commit_time").gt("2"),
+            FilterField::new("age").lt("40"),
+        ];
+        let reader = FileGroupReader::new_with_filters(storage.clone(), &filters, &schema)?;
+        let mask = reader.create_boolean_array_mask(&records)?;
+        assert_eq!(
+            mask,
+            BooleanArray::from(vec![false, false, true, false, false]),
+            "Expected only record with commit_time > '2' AND age < 40"
+        );
+
+        // Test case 4: Filter resulting in all false
+        let filters = vec![FilterField::new("age").gt("100")];
+        let reader = FileGroupReader::new_with_filters(storage.clone(), &filters, &schema)?;
+        let mask = reader.create_boolean_array_mask(&records)?;
+        assert_eq!(mask, BooleanArray::from(vec![false; 5]));
+
+        // Test case 5: A filter column absent from the batch is reported as an error
+        let other_schema = Schema::new(vec![Field::new("missing_col", DataType::Utf8, false)]);
+        let filters = vec![FilterField::new("missing_col").eq("x")];
+        let reader = FileGroupReader::new_with_filters(storage, &filters, &other_schema)?;
+        let result = reader.create_boolean_array_mask(&records);
+        assert!(matches!(result.unwrap_err(), ReadFileSliceError(_)));
+
+        Ok(())
     }
 }
diff --git a/crates/core/src/table/partition.rs b/crates/core/src/table/partition.rs
index b781fca..0386ad2 100644
--- a/crates/core/src/table/partition.rs
+++ b/crates/core/src/table/partition.rs
@@ -20,11 +20,9 @@ use crate::config::table::HudiTableConfig;
 use crate::config::HudiConfigs;
 use crate::error::CoreError::InvalidPartitionPath;
 use crate::expr::filter::{Filter, SchemableFilter};
-use crate::expr::ExprOperator;
 use crate::Result;
 
 use arrow_array::{ArrayRef, Scalar};
-use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq};
 use arrow_schema::Schema;
 
 use std::collections::HashMap;
@@ -90,16 +88,7 @@ impl PartitionPruner {
         self.and_filters.iter().all(|filter| {
             match segments.get(filter.field.name()) {
                 Some(segment_value) => {
-                    let comparison_result = match filter.operator {
-                        ExprOperator::Eq => eq(segment_value, &filter.value),
-                        ExprOperator::Ne => neq(segment_value, &filter.value),
-                        ExprOperator::Lt => lt(segment_value, &filter.value),
-                        ExprOperator::Lte => lt_eq(segment_value, 
&filter.value),
-                        ExprOperator::Gt => gt(segment_value, &filter.value),
-                        ExprOperator::Gte => gt_eq(segment_value, 
&filter.value),
-                    };
-
-                    match comparison_result {
+                    match filter.apply_comparsion(segment_value) {
                         Ok(scalar) => scalar.value(0),
                         Err(_) => true, // Include the partition when 
comparison error occurs
                     }
@@ -161,6 +150,7 @@ mod tests {
     use crate::config::table::HudiTableConfig::{
         IsHiveStylePartitioning, IsPartitionPathUrlencoded,
     };
+    use crate::expr::ExprOperator;
 
     use arrow::datatypes::{DataType, Field, Schema};
     use arrow_array::Date32Array;

Reply via email to