This is an automated email from the ASF dual-hosted git repository.

xushiyan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/hudi-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new e8fde26  feat: implement datafusion API using ParquetExec (#35)
e8fde26 is described below

commit e8fde26df8cdd5355aacce4232138222ce00baf4
Author: Shiyan Xu <[email protected]>
AuthorDate: Wed Jul 3 23:59:11 2024 -0500

    feat: implement datafusion API using ParquetExec (#35)
    
    - upgrade arrow from `50` to `52.0.0`
    - upgrade datafusion from `35` to `39.0.0`
    - use DataFusion's `ParquetExec` to implement `TableProvider` for Hudi tables
    - add `hoodie.read.input.partitions` config
---
 Cargo.toml                       |  38 +++---
 crates/core/src/config/mod.rs    | 118 ++++++++++++++++++
 crates/core/src/lib.rs           |   3 +-
 crates/core/src/storage/mod.rs   |   8 +-
 crates/core/src/storage/utils.rs |  47 ++++++-
 crates/core/src/table/mod.rs     |  17 ++-
 crates/datafusion/Cargo.toml     |   8 +-
 crates/datafusion/src/lib.rs     | 261 ++++++++++++++++++---------------------
 python/Cargo.toml                |   2 +-
 python/hudi/_internal.pyi        |   3 +-
 python/hudi/_utils.py            |  23 ----
 python/hudi/table.py             |  10 +-
 python/src/lib.rs                |  12 +-
 python/tests/test_table_read.py  |   4 +-
 14 files changed, 344 insertions(+), 210 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 1b66057..82f0383 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -30,27 +30,27 @@ rust-version = "1.75.0"
 
 [workspace.dependencies]
 # arrow
-arrow = { version = "50", features = ["pyarrow"] }
-arrow-arith = { version = "50" }
-arrow-array = { version = "50", features = ["chrono-tz"] }
-arrow-buffer = { version = "50" }
-arrow-cast = { version = "50" }
-arrow-ipc = { version = "50" }
-arrow-json = { version = "50" }
-arrow-ord = { version = "50" }
-arrow-row = { version = "50" }
-arrow-schema = { version = "50" }
-arrow-select = { version = "50" }
-object_store = { version = "0.9.1" }
-parquet = { version = "50" }
+arrow = { version = "52.0.0", features = ["pyarrow"]}
+arrow-arith = { version = "52.0.0" }
+arrow-array = { version = "52.0.0", features = ["chrono-tz"] }
+arrow-buffer = { version = "52.0.0" }
+arrow-cast = { version = "52.0.0" }
+arrow-ipc = { version = "52.0.0" }
+arrow-json = { version = "52.0.0" }
+arrow-ord = { version = "52.0.0" }
+arrow-row = { version = "52.0.0" }
+arrow-schema = { version = "52.0.0" }
+arrow-select = { version = "52.0.0" }
+object_store = { version = "0.10.1" }
+parquet = { version = "52.0.0" }
 
 # datafusion
-datafusion = { version = "35" }
-datafusion-expr = { version = "35" }
-datafusion-common = { version = "35" }
-datafusion-proto = { version = "35" }
-datafusion-sql = { version = "35" }
-datafusion-physical-expr = { version = "35" }
+datafusion = { version = "39.0.0" }
+datafusion-expr = { version = "39.0.0" }
+datafusion-common = { version = "39.0.0" }
+datafusion-proto = { version = "39.0.0" }
+datafusion-sql = { version = "39.0.0" }
+datafusion-physical-expr = { version = "39.0.0" }
 
 # serde
 serde = { version = "1.0.203", features = ["derive"] }
diff --git a/crates/core/src/config/mod.rs b/crates/core/src/config/mod.rs
new file mode 100644
index 0000000..3322df3
--- /dev/null
+++ b/crates/core/src/config/mod.rs
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+use std::collections::HashMap;
+
+use anyhow::{anyhow, Context, Result};
+
+pub trait OptionsParser {
+    type Output;
+
+    fn parse_value(&self, options: &HashMap<String, String>) -> 
Result<Self::Output>;
+
+    fn parse_value_or_default(&self, options: &HashMap<String, String>) -> 
Self::Output;
+}
+
+#[derive(Clone, Debug, PartialEq, Eq, Hash)]
+pub enum HudiConfig {
+    ReadInputPartitions,
+}
+
+#[derive(Debug)]
+pub enum HudiConfigValue {
+    Integer(isize),
+}
+
+impl HudiConfigValue {
+    pub fn cast<T: 'static + TryFrom<isize> + TryFrom<usize> + 
std::fmt::Debug>(&self) -> T {
+        match self {
+            HudiConfigValue::Integer(value) => 
T::try_from(*value).unwrap_or_else(|_| {
+                panic!("Failed to convert isize to {}", 
std::any::type_name::<T>())
+            }),
+        }
+    }
+}
+
+impl HudiConfig {
+    fn default_value(&self) -> Option<HudiConfigValue> {
+        match self {
+            Self::ReadInputPartitions => Some(HudiConfigValue::Integer(0)),
+        }
+    }
+}
+
+impl AsRef<str> for HudiConfig {
+    fn as_ref(&self) -> &str {
+        match self {
+            Self::ReadInputPartitions => "hoodie.read.input.partitions",
+        }
+    }
+}
+
+impl OptionsParser for HudiConfig {
+    type Output = HudiConfigValue;
+
+    fn parse_value(&self, options: &HashMap<String, String>) -> 
Result<Self::Output> {
+        match self {
+            HudiConfig::ReadInputPartitions => 
options.get(self.as_ref()).map_or_else(
+                || Err(anyhow!("Config '{}' not found", self.as_ref())),
+                |v| {
+                    v.parse::<isize>()
+                        .map(HudiConfigValue::Integer)
+                        .with_context(|| {
+                            format!("Failed to parse '{}' for config '{}'", v, 
self.as_ref())
+                        })
+                },
+            ),
+        }
+    }
+
+    fn parse_value_or_default(&self, options: &HashMap<String, String>) -> 
Self::Output {
+        self.parse_value(options).unwrap_or_else(|_| {
+            self.default_value()
+                .unwrap_or_else(|| panic!("No default value for config '{}'", 
self.as_ref()))
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::config::HudiConfig::ReadInputPartitions;
+    use crate::config::OptionsParser;
+    use std::collections::HashMap;
+
+    #[test]
+    fn parse_invalid_config_value() {
+        let options =
+            HashMap::from([(ReadInputPartitions.as_ref().to_string(), 
"foo".to_string())]);
+        let value = ReadInputPartitions.parse_value(&options);
+        assert_eq!(
+            value.err().unwrap().to_string(),
+            format!(
+                "Failed to parse 'foo' for config '{}'",
+                ReadInputPartitions.as_ref()
+            )
+        );
+        assert_eq!(
+            ReadInputPartitions
+                .parse_value_or_default(&options)
+                .cast::<isize>(),
+            0
+        );
+    }
+}
diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs
index 533d0fb..80d778f 100644
--- a/crates/core/src/lib.rs
+++ b/crates/core/src/lib.rs
@@ -22,7 +22,8 @@ use crate::table::Table;
 pub mod file_group;
 pub mod table;
 pub type HudiTable = Table;
-mod storage;
+pub mod config;
+pub mod storage;
 
 pub fn crate_version() -> &'static str {
     env!("CARGO_PKG_VERSION")
diff --git a/crates/core/src/storage/mod.rs b/crates/core/src/storage/mod.rs
index 43dd0e7..98e1b2d 100644
--- a/crates/core/src/storage/mod.rs
+++ b/crates/core/src/storage/mod.rs
@@ -37,9 +37,9 @@ use url::Url;
 use crate::storage::file_info::FileInfo;
 use crate::storage::utils::join_url_segments;
 
-pub(crate) mod file_info;
-pub(crate) mod file_stats;
-pub(crate) mod utils;
+pub mod file_info;
+pub mod file_stats;
+pub mod utils;
 
 #[allow(dead_code)]
 #[derive(Clone, Debug)]
@@ -351,7 +351,7 @@ mod tests {
         assert_eq!(file_info.name, "a.parquet");
         assert_eq!(
             file_info.uri,
-            storage.base_url.join("a.parquet").unwrap().to_string()
+            storage.base_url.join("a.parquet").unwrap().as_ref()
         );
         assert_eq!(file_info.size, 866);
     }
diff --git a/crates/core/src/storage/utils.rs b/crates/core/src/storage/utils.rs
index d1f8c4a..053366e 100644
--- a/crates/core/src/storage/utils.rs
+++ b/crates/core/src/storage/utils.rs
@@ -17,7 +17,8 @@
  * under the License.
  */
 
-use std::path::Path;
+use std::path::{Path, PathBuf};
+use std::str::FromStr;
 
 use anyhow::{anyhow, Result};
 use url::{ParseError, Url};
@@ -40,6 +41,24 @@ pub fn split_filename(filename: &str) -> Result<(String, 
String)> {
     Ok((stem, extension))
 }
 
+pub fn parse_uri(uri: &str) -> Result<Url> {
+    let mut url = Url::parse(uri)
+        .or(Url::from_file_path(PathBuf::from_str(uri)?))
+        .map_err(|_| anyhow!("Failed to parse uri: {}", uri))?;
+
+    if url.path().ends_with('/') {
+        url.path_segments_mut()
+            .map_err(|_| anyhow!("Failed to parse uri: {}", uri))?
+            .pop();
+    }
+
+    Ok(url)
+}
+
+pub fn get_scheme_authority(url: &Url) -> String {
+    format!("{}://{}", url.scheme(), url.authority())
+}
+
 pub fn join_url_segments(base_url: &Url, segments: &[&str]) -> Result<Url> {
     let mut url = base_url.clone();
 
@@ -63,7 +82,31 @@ mod tests {
 
     use url::Url;
 
-    use crate::storage::utils::join_url_segments;
+    use crate::storage::utils::{join_url_segments, parse_uri};
+
+    #[test]
+    fn parse_valid_uri_in_various_forms() {
+        let urls = vec![
+            parse_uri("/foo/").unwrap(),
+            parse_uri("file:/foo/").unwrap(),
+            parse_uri("file:///foo/").unwrap(),
+            parse_uri("hdfs://foo/").unwrap(),
+            parse_uri("s3://foo").unwrap(),
+            parse_uri("s3://foo/").unwrap(),
+            parse_uri("s3a://foo/bar/").unwrap(),
+            parse_uri("gs://foo/").unwrap(),
+            parse_uri("wasb://foo/bar").unwrap(),
+            parse_uri("wasbs://foo/").unwrap(),
+        ];
+        let schemes = vec![
+            "file", "file", "file", "hdfs", "s3", "s3", "s3a", "gs", "wasb", 
"wasbs",
+        ];
+        let paths = vec![
+            "/foo", "/foo", "/foo", "/", "", "/", "/bar", "/", "/bar", "/",
+        ];
+        assert_eq!(urls.iter().map(|u| u.scheme()).collect::<Vec<_>>(), 
schemes);
+        assert_eq!(urls.iter().map(|u| u.path()).collect::<Vec<_>>(), paths);
+    }
 
     #[test]
     fn join_base_url_with_segments() {
diff --git a/crates/core/src/table/mod.rs b/crates/core/src/table/mod.rs
index d1be74f..f3c9365 100644
--- a/crates/core/src/table/mod.rs
+++ b/crates/core/src/table/mod.rs
@@ -19,7 +19,6 @@
 
 use std::collections::HashMap;
 use std::io::{BufRead, BufReader};
-use std::path::PathBuf;
 use std::str::FromStr;
 use std::sync::Arc;
 
@@ -29,6 +28,7 @@ use arrow_schema::Schema;
 use url::Url;
 
 use crate::file_group::FileSlice;
+use crate::storage::utils::parse_uri;
 use crate::storage::Storage;
 use crate::table::config::BaseFileFormat;
 use crate::table::config::{ConfigKey, TableType};
@@ -52,9 +52,7 @@ pub struct Table {
 
 impl Table {
     pub async fn new(base_uri: &str, storage_options: HashMap<String, String>) 
-> Result<Self> {
-        let base_url = Url::from_file_path(PathBuf::from(base_uri))
-            .map_err(|_| anyhow!("Failed to create table URL: {}", base_uri))?;
-        let base_url = Arc::new(base_url);
+        let base_url = Arc::new(parse_uri(base_uri)?);
         let storage_options = Arc::new(storage_options);
 
         let props = Self::load_properties(base_url.clone(), 
storage_options.clone())
@@ -114,6 +112,17 @@ impl Table {
         self.timeline.get_latest_schema().await
     }
 
+    pub async fn split_file_slices(&self, n: usize) -> 
Result<Vec<Vec<FileSlice>>> {
+        let n = std::cmp::max(1, n);
+        let file_slices = self.get_file_slices().await?;
+        let chunk_size = (file_slices.len() + n - 1) / n;
+
+        Ok(file_slices
+            .chunks(chunk_size)
+            .map(|chunk| chunk.to_vec())
+            .collect())
+    }
+
     pub async fn get_file_slices(&self) -> Result<Vec<FileSlice>> {
         if let Some(timestamp) = self.timeline.get_latest_commit_timestamp() {
             self.get_file_slices_as_of(timestamp).await
diff --git a/crates/datafusion/Cargo.toml b/crates/datafusion/Cargo.toml
index bb1f5df..dd9f3d6 100644
--- a/crates/datafusion/Cargo.toml
+++ b/crates/datafusion/Cargo.toml
@@ -56,18 +56,14 @@ serde = { workspace = true, features = ["derive"] }
 serde_json = { workspace = true }
 
 # async
+async-trait = { workspace = true }
+futures = { workspace = true }
 tokio = { workspace = true }
 
 # "stdlib"
 anyhow = { workspace = true }
 bytes = { workspace = true }
 chrono = { workspace = true, default-features = false, features = ["clock"] }
-hashbrown = "0.14.3"
 regex = { workspace = true }
 uuid = { workspace = true, features = ["serde", "v4"] }
 url = { workspace = true }
-
-# test
-tempfile = "3.10.1"
-zip-extract = "0.1.3"
-async-trait = "0.1.79"
diff --git a/crates/datafusion/src/lib.rs b/crates/datafusion/src/lib.rs
index f677247..99f1c7e 100644
--- a/crates/datafusion/src/lib.rs
+++ b/crates/datafusion/src/lib.rs
@@ -23,55 +23,45 @@ use std::fmt::Debug;
 use std::sync::Arc;
 use std::thread;
 
-use arrow_array::RecordBatch;
 use arrow_schema::SchemaRef;
 use async_trait::async_trait;
+use datafusion::datasource::listing::PartitionedFile;
+use datafusion::datasource::object_store::ObjectStoreUrl;
+use datafusion::datasource::physical_plan::parquet::ParquetExecBuilder;
+use datafusion::datasource::physical_plan::FileScanConfig;
 use datafusion::datasource::TableProvider;
 use datafusion::execution::context::SessionState;
-use datafusion::execution::{SendableRecordBatchStream, TaskContext};
-use datafusion::physical_plan::memory::MemoryStream;
-use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan};
-use datafusion_common::{project_schema, DataFusionError};
+use datafusion::physical_plan::ExecutionPlan;
+use datafusion_common::Result;
+use datafusion_common::{DFSchema, DataFusionError};
 use datafusion_expr::{Expr, TableType};
-use datafusion_physical_expr::PhysicalSortExpr;
+use datafusion_physical_expr::create_physical_expr;
+use DataFusionError::Execution;
 
+use hudi_core::config::HudiConfig::ReadInputPartitions;
+use hudi_core::config::OptionsParser;
+use hudi_core::storage::utils::{get_scheme_authority, parse_uri};
 use hudi_core::HudiTable;
 
-#[derive(Debug, Clone)]
+#[derive(Clone, Debug)]
 pub struct HudiDataSource {
     table: Arc<HudiTable>,
+    input_partitions: usize,
 }
 
 impl HudiDataSource {
-    pub async fn new(
-        base_uri: &str,
-        storage_options: HashMap<String, String>,
-    ) -> datafusion_common::Result<Self> {
-        match HudiTable::new(base_uri, storage_options).await {
-            Ok(t) => Ok(Self { table: Arc::new(t) }),
-            Err(e) => Err(DataFusionError::Execution(format!(
-                "Failed to create Hudi table: {}",
-                e
-            ))),
+    pub async fn new(base_uri: &str, options: HashMap<String, String>) -> 
Result<Self> {
+        let input_partitions = ReadInputPartitions
+            .parse_value_or_default(&options)
+            .cast::<usize>();
+        match HudiTable::new(base_uri, options).await {
+            Ok(t) => Ok(Self {
+                table: Arc::new(t),
+                input_partitions,
+            }),
+            Err(e) => Err(Execution(format!("Failed to create Hudi table: {}", 
e))),
         }
     }
-
-    pub(crate) async fn create_physical_plan(
-        &self,
-        projections: Option<&Vec<usize>>,
-        schema: SchemaRef,
-    ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
-        Ok(Arc::new(HudiExec::new(projections, schema, self.clone())))
-    }
-
-    async fn get_record_batches(&mut self) -> 
datafusion_common::Result<Vec<RecordBatch>> {
-        let record_batches =
-            self.table.read_snapshot().await.map_err(|e| {
-                DataFusionError::Execution(format!("Failed to read snapshot: 
{}", e))
-            })?;
-
-        Ok(record_batches)
-    }
 }
 
 #[async_trait]
@@ -95,81 +85,46 @@ impl TableProvider for HudiDataSource {
 
     async fn scan(
         &self,
-        _state: &SessionState,
+        state: &SessionState,
         projection: Option<&Vec<usize>>,
-        _filters: &[Expr],
-        _limit: Option<usize>,
-    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
-        return self.create_physical_plan(projection, self.schema()).await;
-    }
-}
-
-#[derive(Debug, Clone)]
-pub struct HudiExec {
-    data_source: HudiDataSource,
-    projected_schema: SchemaRef,
-}
-
-impl HudiExec {
-    fn new(
-        projections: Option<&Vec<usize>>,
-        schema: SchemaRef,
-        data_source: HudiDataSource,
-    ) -> Self {
-        let projected_schema = project_schema(&schema, projections).unwrap();
-        Self {
-            data_source,
-            projected_schema,
+        filters: &[Expr],
+        limit: Option<usize>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let file_slices = self
+            .table
+            .split_file_slices(self.input_partitions)
+            .await
+            .map_err(|e| Execution(format!("Failed to get file slices from 
Hudi table: {}", e)))?;
+        let mut parquet_file_groups: Vec<Vec<PartitionedFile>> = Vec::new();
+        for file_slice_vec in file_slices {
+            let parquet_file_group_vec = file_slice_vec
+                .iter()
+                .map(|f| {
+                    let url = parse_uri(&f.base_file.info.uri).unwrap();
+                    let size = f.base_file.info.size as u64;
+                    PartitionedFile::new(url.path(), size)
+                })
+                .collect();
+            parquet_file_groups.push(parquet_file_group_vec)
         }
-    }
-}
 
-impl DisplayAs for HudiExec {
-    fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> 
std::fmt::Result {
-        write!(f, "HudiExec")
-    }
-}
-
-impl ExecutionPlan for HudiExec {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn schema(&self) -> SchemaRef {
-        self.projected_schema.clone()
-    }
-
-    fn output_partitioning(&self) -> datafusion::physical_plan::Partitioning {
-        datafusion::physical_plan::Partitioning::UnknownPartitioning(1)
-    }
+        let url = 
ObjectStoreUrl::parse(get_scheme_authority(&self.table.base_url))?;
+        let fsc = FileScanConfig::new(url, self.schema())
+            .with_file_groups(parquet_file_groups)
+            .with_projection(projection.cloned())
+            .with_limit(limit);
 
-    fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
-        None
-    }
-
-    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
-        vec![]
-    }
+        let parquet_opts = state.table_options().parquet.clone();
+        let mut exec_builder = ParquetExecBuilder::new_with_options(fsc, 
parquet_opts);
 
-    fn with_new_children(
-        self: Arc<Self>,
-        _children: Vec<Arc<dyn ExecutionPlan>>,
-    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
-        Ok(self)
-    }
+        let filter = filters.iter().cloned().reduce(|acc, new| acc.and(new));
+        if let Some(expr) = filter {
+            let df_schema = DFSchema::try_from(self.schema())?;
+            let predicate = create_physical_expr(&expr, &df_schema, 
state.execution_props())?;
+            exec_builder = exec_builder.with_predicate(predicate)
+        }
 
-    fn execute(
-        &self,
-        _partition: usize,
-        _context: Arc<TaskContext>,
-    ) -> datafusion_common::Result<SendableRecordBatchStream> {
-        let mut data_source = self.data_source.clone();
-        let handle = thread::spawn(move || {
-            let rt = tokio::runtime::Runtime::new().unwrap();
-            rt.block_on(data_source.get_record_batches()).unwrap()
-        });
-        let data = handle.join().unwrap();
-        Ok(Box::pin(MemoryStream::try_new(data, self.schema(), None)?))
+        return Ok(exec_builder.build_arc());
     }
 }
 
@@ -178,64 +133,92 @@ mod tests {
     use std::collections::HashMap;
     use std::sync::Arc;
 
-    use arrow_array::{Array, Int32Array, StringArray};
-    use datafusion::dataframe::DataFrame;
+    use arrow_array::{Array, Int32Array, RecordBatch, StringArray};
     use datafusion::prelude::{SessionConfig, SessionContext};
-    use datafusion_common::ScalarValue;
+    use datafusion_common::{DataFusionError, Result, ScalarValue};
 
+    use hudi_core::config::HudiConfig;
     use hudi_tests::TestTable;
 
     use crate::HudiDataSource;
 
     #[tokio::test]
-    async fn datafusion_read_hudi_table() {
+    async fn datafusion_read_hudi_table() -> Result<(), DataFusionError> {
         let config = SessionConfig::new().set(
             "datafusion.sql_parser.enable_ident_normalization",
             ScalarValue::from(false),
         );
         let ctx = SessionContext::new_with_config(config);
         let base_url = TestTable::V6ComplexkeygenHivestyle.url();
-        let hudi = HudiDataSource::new(base_url.path(), HashMap::new())
-            .await
-            .unwrap();
-        ctx.register_table("hudi_table_complexkeygen", Arc::new(hudi))
-            .unwrap();
-        let df: DataFrame = ctx
-            .sql("SELECT * from hudi_table_complexkeygen where 
structField.field2 > 30 order by name")
-            .await.unwrap();
-        let records = df
-            .collect()
-            .await
-            .unwrap()
-            .to_vec()
-            .first()
-            .unwrap()
-            .to_owned();
-        let files: Vec<String> = records
-            .column_by_name("_hoodie_file_name")
-            .unwrap()
-            .as_any()
-            .downcast_ref::<StringArray>()
-            .unwrap()
-            .iter()
-            .map(|s| s.unwrap_or_default().to_string())
-            .collect();
+        let hudi = HudiDataSource::new(
+            base_url.as_str(),
+            HashMap::from([(
+                HudiConfig::ReadInputPartitions.as_ref().to_string(),
+                "2".to_string(),
+            )]),
+        )
+        .await?;
+        ctx.register_table("hudi_table_complexkeygen", Arc::new(hudi))?;
+        let sql = r#"
+        SELECT _hoodie_file_name, id, name, structField.field2
+        FROM hudi_table_complexkeygen WHERE id % 2 = 0
+        AND structField.field2 > 30 ORDER BY name LIMIT 10"#;
+
+        // verify plan
+        let explaining_df = ctx.sql(sql).await?.explain(false, true).unwrap();
+        let explaining_rb = explaining_df.collect().await?;
+        let explaining_rb = explaining_rb.first().unwrap();
+        let plan = get_str_column(explaining_rb, "plan").join("");
+        let plan_lines: Vec<&str> = plan.lines().map(str::trim).collect();
+        assert!(plan_lines[2].starts_with("SortExec: TopK(fetch=10)"));
+        assert!(plan_lines[3].starts_with("ProjectionExec: 
expr=[_hoodie_file_name@0 as _hoodie_file_name, id@1 as id, name@2 as name, 
get_field(structField@3, field2) as 
hudi_table_complexkeygen.structField[field2]]"));
+        assert!(plan_lines[5].starts_with(
+            "FilterExec: CAST(id@1 AS Int64) % 2 = 0 AND 
get_field(structField@3, field2) > 30"
+        ));
+        assert!(plan_lines[6].contains("input_partitions=2"));
+
+        // verify data
+        let df = ctx.sql(sql).await?;
+        let rb = df.collect().await?;
+        let rb = rb.first().unwrap();
         assert_eq!(
-            files,
-            vec![
+            get_str_column(rb, "_hoodie_file_name"),
+            &[
                 
"bb7c3a45-387f-490d-aab2-981c3f1a8ada-0_0-140-198_20240418173213674.parquet",
                 
"4668e35e-bff8-4be9-9ff2-e7fb17ecb1a7-0_1-161-224_20240418173235694.parquet"
             ]
         );
-        let ids: Vec<i32> = records
-            .column_by_name("id")
+        assert_eq!(get_i32_column(rb, "id"), &[2, 4]);
+        assert_eq!(get_str_column(rb, "name"), &["Bob", "Diana"]);
+        assert_eq!(
+            get_i32_column(rb, "hudi_table_complexkeygen.structField[field2]"),
+            &[40, 50]
+        );
+
+        Ok(())
+    }
+
+    fn get_str_column<'a>(record_batch: &'a RecordBatch, name: &str) -> 
Vec<&'a str> {
+        record_batch
+            .column_by_name(name)
+            .unwrap()
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .unwrap()
+            .iter()
+            .map(|s| s.unwrap())
+            .collect::<Vec<_>>()
+    }
+
+    fn get_i32_column(record_batch: &RecordBatch, name: &str) -> Vec<i32> {
+        record_batch
+            .column_by_name(name)
             .unwrap()
             .as_any()
             .downcast_ref::<Int32Array>()
             .unwrap()
             .iter()
-            .map(|i| i.unwrap_or_default())
-            .collect();
-        assert_eq!(ids, vec![2, 4])
+            .map(|s| s.unwrap())
+            .collect::<Vec<_>>()
     }
 }
diff --git a/python/Cargo.toml b/python/Cargo.toml
index 8f1d079..4eca0d2 100644
--- a/python/Cargo.toml
+++ b/python/Cargo.toml
@@ -49,7 +49,7 @@ futures = { workspace = true }
 tokio = { workspace = true }
 
 [dependencies.pyo3]
-version = "0.20.3"
+version = "0.21.2"
 features = ["extension-module", "abi3", "abi3-py38", "anyhow"]
 
 [dependencies.hudi]
diff --git a/python/hudi/_internal.pyi b/python/hudi/_internal.pyi
index b91b492..3485316 100644
--- a/python/hudi/_internal.pyi
+++ b/python/hudi/_internal.pyi
@@ -33,7 +33,6 @@ class HudiFileSlice:
     base_file_size: int
     num_records: int
 
-    @property
     def base_file_relative_path(self) -> str: ...
 
 
@@ -47,6 +46,8 @@ class BindingHudiTable:
 
     def get_schema(self) -> "pyarrow.Schema": ...
 
+    def split_file_slices(self, n: int) -> List[List[HudiFileSlice]]: ...
+
     def get_file_slices(self) -> List[HudiFileSlice]: ...
 
     def read_file_slice(self, base_file_relative_path) -> pyarrow.RecordBatch: 
...
diff --git a/python/hudi/_utils.py b/python/hudi/_utils.py
deleted file mode 100644
index 779a14b..0000000
--- a/python/hudi/_utils.py
+++ /dev/null
@@ -1,23 +0,0 @@
-#  Licensed to the Apache Software Foundation (ASF) under one
-#  or more contributor license agreements.  See the NOTICE file
-#  distributed with this work for additional information
-#  regarding copyright ownership.  The ASF licenses this file
-#  to you under the Apache License, Version 2.0 (the
-#  "License"); you may not use this file except in compliance
-#  with the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-#  Unless required by applicable law or agreed to in writing,
-#  software distributed under the License is distributed on an
-#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-#  KIND, either express or implied.  See the License for the
-#  specific language governing permissions and limitations
-#  under the License.
-from typing import List, Any, Iterator
-
-
-def split_list(lst: List[Any], n: int) -> Iterator[List[Any]]:
-    split_size = (len(lst) + n - 1) // n
-    for i in range(0, len(lst), split_size):
-        yield lst[i: i + split_size]
diff --git a/python/hudi/table.py b/python/hudi/table.py
index def024d..107b665 100644
--- a/python/hudi/table.py
+++ b/python/hudi/table.py
@@ -18,12 +18,10 @@
 import os
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Union, List, Iterator, Optional, Dict
+from typing import Union, List, Optional, Dict
 
 import pyarrow
-
 from hudi._internal import BindingHudiTable, HudiFileSlice
-from hudi._utils import split_list
 
 
 @dataclass(init=False)
@@ -39,10 +37,8 @@ class HudiTable:
     def get_schema(self) -> "pyarrow.Schema":
         return self._table.get_schema()
 
-    def split_file_slices(self, n: int) -> Iterator[List[HudiFileSlice]]:
-        file_slices = self.get_file_slices()
-        for split in split_list(file_slices, n):
-            yield split
+    def split_file_slices(self, n: int) -> List[List[HudiFileSlice]]:
+        return self._table.split_file_slices(n)
 
     def get_file_slices(self) -> List[HudiFileSlice]:
         return self._table.get_file_slices()
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 10f39ec..8962cff 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -104,6 +104,16 @@ impl BindingHudiTable {
         rt().block_on(self._table.get_schema())?.to_pyarrow(py)
     }
 
+    pub fn split_file_slices(&self, n: usize, py: Python) -> 
PyResult<Vec<Vec<HudiFileSlice>>> {
+        py.allow_threads(|| {
+            let file_slices = rt().block_on(self._table.split_file_slices(n))?;
+            Ok(file_slices
+                .iter()
+                .map(|inner_vec| 
inner_vec.iter().map(convert_file_slice).collect())
+                .collect())
+        })
+    }
+
     pub fn get_file_slices(&self, py: Python) -> PyResult<Vec<HudiFileSlice>> {
         py.allow_threads(|| {
             let file_slices = rt().block_on(self._table.get_file_slices())?;
@@ -134,7 +144,7 @@ fn rust_core_version() -> &'static str {
 
 #[cfg(not(tarpaulin))]
 #[pymodule]
-fn _internal(_py: Python, m: &PyModule) -> PyResult<()> {
+fn _internal(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add("__version__", env!("CARGO_PKG_VERSION"))?;
     m.add_function(wrap_pyfunction!(rust_core_version, m)?)?;
 
diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py
index 883958c..6370bd5 100644
--- a/python/tests/test_table_read.py
+++ b/python/tests/test_table_read.py
@@ -36,7 +36,7 @@ def test_sample_table(get_sample_table):
     assert len(file_slices) == 5
     assert set(f.commit_time for f in file_slices) == {'20240402123035233', 
'20240402144910683'}
     assert all(f.num_records == 1 for f in file_slices)
-    file_slice_paths = [f.base_file_relative_path for f in file_slices]
+    file_slice_paths = [f.base_file_relative_path() for f in file_slices]
     assert set(file_slice_paths) == 
{'chennai/68d3c349-f621-4cd8-9e8b-c6dd8eb20d08-0_4-12-0_20240402123035233.parquet',
                                      
'san_francisco/d9082ffd-2eb1-4394-aefc-deb4a61ecc57-0_1-9-0_20240402123035233.parquet',
                                      
'san_francisco/780b8586-3ad0-48ef-a6a1-d2217845ce4a-0_0-8-0_20240402123035233.parquet',
@@ -48,7 +48,7 @@ def test_sample_table(get_sample_table):
     assert t.num_rows == 1
     assert t.num_columns == 11
 
-    file_slices_gen = table.split_file_slices(2)
+    file_slices_gen = iter(table.split_file_slices(2))
     assert len(next(file_slices_gen)) == 3
     assert len(next(file_slices_gen)) == 2
 

Reply via email to