This is an automated email from the ASF dual-hosted git repository.

xushiyan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/hudi-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 8fcdf5b  feat: add partition filter arg in Python APIs (#153)
8fcdf5b is described below

commit 8fcdf5b8c9060c81b987c4b7730fe06b5c35ff13
Author: Shiyan Xu <[email protected]>
AuthorDate: Mon Oct 7 11:15:32 2024 -1000

    feat: add partition filter arg in Python APIs (#153)
---
 python/hudi/_internal.pyi       | 12 ++++++----
 python/src/internal.rs          | 49 ++++++++++++++++++++++++++++++++---------
 python/tests/test_table_read.py | 24 ++++++++++++++++++++
 3 files changed, 71 insertions(+), 14 deletions(-)

diff --git a/python/hudi/_internal.pyi b/python/hudi/_internal.pyi
index acf2e16..bc6023b 100644
--- a/python/hudi/_internal.pyi
+++ b/python/hudi/_internal.pyi
@@ -41,8 +41,12 @@ class HudiTable:
         options: Optional[Dict[str, str]] = None,
     ): ...
     def get_schema(self) -> "pyarrow.Schema": ...
-    def split_file_slices(self, n: int) -> List[List[HudiFileSlice]]: ...
-    def get_file_slices(self) -> List[HudiFileSlice]: ...
+    def get_partition_schema(self) -> "pyarrow.Schema": ...
+    def split_file_slices(
+        self, n: int, filters: Optional[List[str]]
+    ) -> List[List[HudiFileSlice]]: ...
+    def get_file_slices(self, filters: Optional[List[str]]) -> List[HudiFileSlice]: ...
     def read_file_slice(self, base_file_relative_path: str) -> pyarrow.RecordBatch: ...
-    def read_snapshot(self) -> List["pyarrow.RecordBatch"]: ...
-    def read_snapshot_as_of(self, timestamp: str) -> List["pyarrow.RecordBatch"]: ...
+    def read_snapshot(
+        self, filters: Optional[List[str]]
+    ) -> List["pyarrow.RecordBatch"]: ...
diff --git a/python/src/internal.rs b/python/src/internal.rs
index 53ea580..70b9f56 100644
--- a/python/src/internal.rs
+++ b/python/src/internal.rs
@@ -85,6 +85,12 @@ fn convert_file_slice(f: &FileSlice) -> HudiFileSlice {
     }
 }
 
+macro_rules! vec_string_to_slice {
+    ($vec:expr) => {
+        &$vec.iter().map(AsRef::as_ref).collect::<Vec<_>>()
+    };
+}
+
 #[cfg(not(tarpaulin))]
 #[pyclass]
 pub struct HudiTable {
@@ -108,10 +114,23 @@ impl HudiTable {
         rt().block_on(self._table.get_schema())?.to_pyarrow(py)
     }
 
-    fn split_file_slices(&self, n: usize, py: Python) -> PyResult<Vec<Vec<HudiFileSlice>>> {
-        // TODO: support passing filters
+    fn get_partition_schema(&self, py: Python) -> PyResult<PyObject> {
+        rt().block_on(self._table.get_partition_schema())?
+            .to_pyarrow(py)
+    }
+
+    #[pyo3(signature = (n, filters=None))]
+    fn split_file_slices(
+        &self,
+        n: usize,
+        filters: Option<Vec<String>>,
+        py: Python,
+    ) -> PyResult<Vec<Vec<HudiFileSlice>>> {
         py.allow_threads(|| {
-            let file_slices = rt().block_on(self._table.split_file_slices(n, &[]))?;
+            let file_slices = rt().block_on(
+                self._table
+                    .split_file_slices(n, vec_string_to_slice!(filters.unwrap_or_default())),
+            )?;
             Ok(file_slices
                 .iter()
                 .map(|inner_vec| inner_vec.iter().map(convert_file_slice).collect())
@@ -119,10 +138,17 @@ impl HudiTable {
         })
     }
 
-    fn get_file_slices(&self, py: Python) -> PyResult<Vec<HudiFileSlice>> {
-        // TODO: support passing filters
+    #[pyo3(signature = (filters=None))]
+    fn get_file_slices(
+        &self,
+        filters: Option<Vec<String>>,
+        py: Python,
+    ) -> PyResult<Vec<HudiFileSlice>> {
         py.allow_threads(|| {
-            let file_slices = rt().block_on(self._table.get_file_slices(&[]))?;
+            let file_slices = rt().block_on(
+                self._table
+                    .get_file_slices(vec_string_to_slice!(filters.unwrap_or_default())),
+            )?;
             Ok(file_slices.iter().map(convert_file_slice).collect())
         })
     }
@@ -132,10 +158,13 @@ impl HudiTable {
             .to_pyarrow(py)
     }
 
-    fn read_snapshot(&self, py: Python) -> PyResult<PyObject> {
-        // TODO: support passing filters
-        rt().block_on(self._table.read_snapshot(&[]))?
-            .to_pyarrow(py)
+    #[pyo3(signature = (filters=None))]
+    fn read_snapshot(&self, filters: Option<Vec<String>>, py: Python) -> PyResult<PyObject> {
+        rt().block_on(
+            self._table
+                .read_snapshot(vec_string_to_slice!(filters.unwrap_or_default())),
+        )?
+        .to_pyarrow(py)
     }
 }
 
diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py
index 6a517e5..2d1fd74 100644
--- a/python/tests/test_table_read.py
+++ b/python/tests/test_table_read.py
@@ -47,6 +47,7 @@ def test_sample_table(get_sample_table):
         "fare",
         "city",
     ]
+    assert table.get_partition_schema().names == ["city"]
 
     file_slices = table.get_file_slices()
     assert len(file_slices) == 5
@@ -109,6 +110,29 @@ def test_sample_table(get_sample_table):
         },
     ]
 
+    batches = table.read_snapshot(["city = san_francisco"])
+    t = pa.Table.from_batches(batches).select([0, 5, 6, 9]).sort_by("ts")
+    assert t.to_pylist() == [
+        {
+            "_hoodie_commit_time": "20240402144910683",
+            "ts": 1695046462179,
+            "uuid": "9909a8b1-2d15-4d3d-8ec9-efc48c536a00",
+            "fare": 339.0,
+        },
+        {
+            "_hoodie_commit_time": "20240402123035233",
+            "ts": 1695091554788,
+            "uuid": "e96c4396-3fad-413a-a942-4cb36106d721",
+            "fare": 27.7,
+        },
+        {
+            "_hoodie_commit_time": "20240402123035233",
+            "ts": 1695159649087,
+            "uuid": "334e26e9-8355-45cc-97c6-c31daf0df330",
+            "fare": 19.1,
+        },
+    ]
+
     table = HudiTable(table_path, {"hoodie.read.as.of.timestamp": "20240402123035233"})
     batches = table.read_snapshot()
     t = pa.Table.from_batches(batches).select([0, 5, 6, 9]).sort_by("ts")

Reply via email to